diff --git a/.gitignore b/.gitignore index cfdf0275..1724bd43 100644 --- a/.gitignore +++ b/.gitignore @@ -24,5 +24,7 @@ tools/montreal-forced-aligner/ tools/Montreal-Forced-Aligner/ tools/sctk tools/sctk-20159b5/ +tools/kaldi +tools/OpenBLAS/ *output/ diff --git a/deepspeech/decoders/recog.py b/deepspeech/decoders/recog.py index 6dea6b70..d1ddfc8a 100644 --- a/deepspeech/decoders/recog.py +++ b/deepspeech/decoders/recog.py @@ -24,6 +24,7 @@ from .utils import add_results_to_json from deepspeech.exps import dynamic_import_tester from deepspeech.io.reader import LoadInputsAndTargets from deepspeech.models.asr_interface import ASRInterface +from deepspeech.models.lm.transformer import TransformerLM from deepspeech.utils.log import Log # from espnet.asr.asr_utils import get_model_conf # from espnet.asr.asr_utils import torch_load @@ -48,6 +49,21 @@ def load_trained_model(args): model = exp.model return model, char_list, exp, confs +def get_config(config_path): + stream = open(config_path, mode='r', encoding="utf-8") + config = yaml.load(stream, Loader=yaml.FullLoader) + stream.close() + return config + +def load_trained_lm(args): + lm_args = get_config(args.rnnlm_conf) + # NOTE: for a compatibility with less than 0.5.0 version models + lm_model_module = getattr(lm_args, "model_module", "default") + lm_class = dynamic_import_lm(lm_model_module) + lm = lm_class(lm_args.model) + model_dict = paddle.load(args.rnnlm) + lm.set_state_dict(model_dict) + return lm def recog_v2(args): """Decode with custom models that implements ScorerInterface. @@ -78,12 +94,7 @@ def recog_v2(args): preprocess_args={"train": False}, ) if args.rnnlm: - lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) - # NOTE: for a compatibility with less than 0.5.0 version models - lm_model_module = getattr(lm_args, "model_module", "default") - lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) - lm = lm_class(len(char_list), lm_args) - torch_load(args.rnnlm, lm) + lm = load_trained_lm(args) lm.eval() else: lm = None diff --git a/deepspeech/decoders/recog_bin.py b/deepspeech/decoders/recog_bin.py index fbf582f7..cb3d5757 100644 --- a/deepspeech/decoders/recog_bin.py +++ b/deepspeech/decoders/recog_bin.py @@ -21,9 +21,6 @@ from distutils.util import strtobool import configargparse import numpy as np -from deepspeech.decoders.recog import recog_v2 - - def get_parser(): """Get default arguments.""" parser = configargparse.ArgumentParser( @@ -359,7 +356,7 @@ def main(args): if args.num_encs == 1: # Experimental API that supports custom LMs if args.api == "v2": - + from deepspeech.decoders.recog import recog_v2 recog_v2(args) else: raise ValueError("Only support --api v2") diff --git a/deepspeech/decoders/scorers/ctc_prefix_score.py b/deepspeech/decoders/scorers/ctc_prefix_score.py index c85d546d..13429d49 100644 --- a/deepspeech/decoders/scorers/ctc_prefix_score.py +++ b/deepspeech/decoders/scorers/ctc_prefix_score.py @@ -318,6 +318,18 @@ class CTCPrefixScore(): r[0, 0] = xs[0] r[0, 1] = self.logzero else: + # Although the code does not exactly follow Algorithm 2, + # we don't have to change it because we can assume + # r_t(h)=0 for t < |h| in CTC forward computation + # (Note: we assume here that index t starts with 0). + # The purpose of this difference is to reduce the number of for-loops. + # https://github.com/espnet/espnet/pull/3655 + # where we start to accumulate r_t(h) from t=|h| + # and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1, + # avoiding accumulating zeros for t=1~|h|-1. 
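+        # (For example, with |h| = 3, the recursion only fills r_t(h) for t = 3, ..., T-1.)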
+ # Thus, we need to set r_{|h|-1}(h) = 0, + # i.e., r[output_length-1] = logzero, for initialization. + # This is just for reducing the computation. r[output_length - 1] = self.logzero # prepare forward probabilities for the last label diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py index 0de81333..d2316ab1 100644 --- a/deepspeech/frontend/augmentor/augmentation.py +++ b/deepspeech/frontend/augmentor/augmentation.py @@ -13,6 +13,7 @@ # limitations under the License. """Contains the data augmentation pipeline.""" import json +import os from collections.abc import Sequence from inspect import signature from pprint import pformat @@ -90,9 +91,8 @@ class AugmentationPipeline(): effect. Params: - augmentation_config(str): Augmentation configuration in json string. + preprocess_conf(str): Augmentation configuration in `json file` or `json string`. random_seed(int): Random seed. - train(bool): whether is train mode. Raises: ValueError: If the augmentation json config is in incorrect format". @@ -100,11 +100,18 @@ class AugmentationPipeline(): SPEC_TYPES = {'specaug'} - def __init__(self, augmentation_config: str, random_seed: int=0): + def __init__(self, preprocess_conf: str, random_seed: int=0): self._rng = np.random.RandomState(random_seed) self.conf = {'mode': 'sequential', 'process': []} - if augmentation_config: - process = json.loads(augmentation_config) + if preprocess_conf: + if os.path.isfile(preprocess_conf): + # json file + with open(preprocess_conf, 'r') as fin: + json_string = fin.read() + else: + # json string + json_string = preprocess_conf + process = json.loads(json_string) self.conf['process'] += process self._augmentors, self._rates = self._parse_pipeline_from('all') diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 5f0bc462..b523dfc8 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -105,7 +105,7 @@ class SpeechCollatorBase(): self._local_data = TarLocalData(tar2info={}, tar2object={}) self.augmentation = AugmentationPipeline( - augmentation_config=aug_file.read(), random_seed=random_seed) + preprocess_conf=aug_file.read(), random_seed=random_seed) self._normalizer = FeatureNormalizer( mean_std_filepath) if mean_std_filepath else None diff --git a/deepspeech/io/reader.py b/deepspeech/io/reader.py index 5873788b..59098752 100644 --- a/deepspeech/io/reader.py +++ b/deepspeech/io/reader.py @@ -17,7 +17,7 @@ import kaldiio import numpy as np import soundfile -from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline +from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline as Transformation from deepspeech.utils.log import Log __all__ = ["LoadInputsAndTargets"] @@ -66,8 +66,7 @@ class LoadInputsAndTargets(): raise ValueError("Only asr are allowed: mode={}".format(mode)) if preprocess_conf is not None: - with open(preprocess_conf, 'r') as fin: - self.preprocessing = AugmentationPipeline(fin.read()) + self.preprocessing = Transformation(preprocess_conf) logger.warning( "[Experimental feature] Some preprocessing will be done " "for the mini-batch creation using {}".format( diff --git a/deepspeech/models/asr_interface.py b/deepspeech/models/asr_interface.py index 7dac81b4..d86daa0b 100644 --- a/deepspeech/models/asr_interface.py +++ b/deepspeech/models/asr_interface.py @@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import class ASRInterface: - """ASR Interface for ESPnet model implementation.""" + """ASR Interface model 
implementation.""" @staticmethod def add_arguments(parser): @@ -103,14 +103,14 @@ class ASRInterface: @property def attention_plot_class(self): """Get attention plot class.""" - from espnet.asr.asr_utils import PlotAttentionReport + from deepspeech.training.extensions.plot import PlotAttentionReport return PlotAttentionReport @property def ctc_plot_class(self): """Get CTC plot class.""" - from espnet.asr.asr_utils import PlotCTCReport + from deepspeech.training.extensions.plot import PlotCTCReport return PlotCTCReport diff --git a/deepspeech/models/lm/__init__.py b/deepspeech/models/lm/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/deepspeech/models/lm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/deepspeech/models/lm/transformer.py b/deepspeech/models/lm/transformer.py new file mode 100644 index 00000000..28371ae2 --- /dev/null +++ b/deepspeech/models/lm/transformer.py @@ -0,0 +1,263 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any +from typing import List +from typing import Tuple + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface +from deepspeech.models.lm_interface import LMInterface +from deepspeech.modules.encoder import TransformerEncoder +from deepspeech.modules.mask import subsequent_mask +from deepspeech.utils.log import Log + +logger = Log(__name__).getlog() + + +class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): + def __init__( + self, + n_vocab: int, + pos_enc: str=None, + embed_unit: int=128, + att_unit: int=256, + head: int=2, + unit: int=1024, + layer: int=4, + dropout_rate: float=0.5, + emb_dropout_rate: float=0.0, + att_dropout_rate: float=0.0, + tie_weights: bool=False, + **kwargs): + nn.Layer.__init__(self) + + if pos_enc == "sinusoidal": + pos_enc_layer_type = "abs_pos" + elif pos_enc is None: + pos_enc_layer_type = "no_pos" + else: + raise ValueError(f"unknown pos-enc option: {pos_enc}") + + self.embed = nn.Embedding(n_vocab, embed_unit) + + if emb_dropout_rate == 0.0: + self.embed_drop = None + else: + self.embed_drop = nn.Dropout(emb_dropout_rate) + + self.encoder = TransformerEncoder( + input_size=embed_unit, + output_size=att_unit, + attention_heads=head, + linear_units=unit, + num_blocks=layer, + dropout_rate=dropout_rate, + attention_dropout_rate=att_dropout_rate, + input_layer="linear", + pos_enc_layer_type=pos_enc_layer_type, + concat_after=False, + static_chunk_size=1, + use_dynamic_chunk=False, + use_dynamic_left_chunk=False) + + self.decoder = nn.Linear(att_unit, n_vocab) + + logger.info("Tie weights set to {}".format(tie_weights)) + logger.info("Dropout set to {}".format(dropout_rate)) + logger.info("Emb Dropout set to {}".format(emb_dropout_rate)) + logger.info("Att Dropout set to {}".format(att_dropout_rate)) + + if tie_weights: + assert ( + att_unit == embed_unit + ), "Tie Weights: True need embedding and final dimensions to match" + self.decoder.weight = self.embed.weight + + def _target_mask(self, ys_in_pad): + ys_mask = ys_in_pad != 0 + m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0) + return ys_mask.unsqueeze(-2) & m + + def forward(self, x: paddle.Tensor, t: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Compute LM loss value from buffer sequences. + + Args: + x (paddle.Tensor): Input ids. (batch, len) + t (paddle.Tensor): Target ids. (batch, len) + + Returns: + tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + xm = x != 0 + xlen = xm.sum(axis=1) + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(x)) + else: + emb = self.embed(x) + h, _ = self.encoder(emb, xlen) + y = self.decoder(h) + loss = F.cross_entropy( + y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + mask = xm.to(dtype=loss.dtype) + logp = loss * mask.view(-1) + logp = logp.sum() + count = mask.sum() + return logp / count, logp, count + + # beam search API (see ScorerInterface) + def score(self, y: paddle.Tensor, state: Any, + x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]: + """Score new token. + + Args: + y (paddle.Tensor): 1D paddle.int64 prefix tokens. + state: Scorer state for prefix tokens + x (paddle.Tensor): encoder feature that generates ys. 
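+                Not used by this Transformer LM, so it may be None.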
+ + Returns: + tuple[paddle.Tensor, Any]: Tuple of + paddle.float32 scores for next token (n_vocab) + and next state for ys + + """ + y = y.unsqueeze(0) + + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(y)) + else: + emb = self.embed(y) + + h, _, cache = self.encoder.forward_one_step( + emb, self._target_mask(y), cache=state) + h = self.decoder(h[:, -1]) + logp = F.log_softmax(h).squeeze(0) + return logp, cache + + # batch beam search API (see BatchScorerInterface) + def batch_score(self, + ys: paddle.Tensor, + states: List[Any], + xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]: + """Score new token batch (required). + + Args: + ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (paddle.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[paddle.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + paddle.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(ys)) + else: + emb = self.embed(ys) + + # batch decoding + h, _, states = self.encoder.forward_one_step( + emb, self._target_mask(ys), cache=batch_state) + h = self.decoder(h[:, -1]) + logp = F.log_softmax(h) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] + for b in range(n_batch)] + return logp, state_list + + +if __name__ == "__main__": + tlm = TransformerLM( + n_vocab=5002, + pos_enc=None, + embed_unit=128, + att_unit=512, + head=8, + unit=2048, + layer=16, + dropout_rate=0.5, ) + + # n_vocab: int, + # pos_enc: str=None, + # embed_unit: int=128, + # att_unit: int=256, + # head: int=2, + # unit: int=1024, + # layer: int=4, + # dropout_rate: float=0.5, + # emb_dropout_rate: float = 0.0, + # att_dropout_rate: float = 0.0, + # tie_weights: bool = False,): + paddle.set_device("cpu") + model_dict = paddle.load("transformerLM.pdparams") + tlm.set_state_dict(model_dict) + + tlm.eval() + #Test the score + input2 = np.array([5]) + input2 = paddle.to_tensor(input2) + state = None + output, state = tlm.score(input2, state, None) + + input3 = np.array([5, 10]) + input3 = paddle.to_tensor(input3) + output, state = tlm.score(input3, state, None) + + input4 = np.array([5, 10, 0]) + input4 = paddle.to_tensor(input4) + output, state = tlm.score(input4, state, None) + print("output", output) + """ + #Test the batch score + batch_size = 2 + inp2 = np.array([[5], [10]]) + inp2 = paddle.to_tensor(inp2) + output, states = tlm.batch_score( + inp2, [(None,None,0)] * batch_size) + inp3 = np.array([[100], [30]]) + inp3 = paddle.to_tensor(inp3) + output, states = tlm.batch_score( + inp3, states) + print("output", output) + #print("cache", cache) + #np.save("output_pd.npy", output) + """ diff --git a/deepspeech/models/lm_interface.py b/deepspeech/models/lm_interface.py new file mode 100644 index 00000000..e2987282 --- /dev/null +++ b/deepspeech/models/lm_interface.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Language model interface.""" +import argparse + +from deepspeech.decoders.scorers.scorer_interface import ScorerInterface +from deepspeech.utils.dynamic_import import dynamic_import + + +class LMInterface(ScorerInterface): + """LM Interface model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to command line argument parser.""" + return parser + + @classmethod + def build(cls, n_vocab: int, **kwargs): + """Initialize this class with python-level args. + + Args: + n_vocab (int): Size of the vocabulary. + + Returns: + LMInterface: A new instance of LMInterface. + + """ + args = argparse.Namespace(**kwargs) + return cls(n_vocab, args) + + def forward(self, x, t): + """Compute LM loss value from buffer sequences. + + Args: + x (paddle.Tensor): Input ids. (batch, len) + t (paddle.Tensor): Target ids. (batch, len) + + Returns: + tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + raise NotImplementedError("forward method is not implemented") + + +predefined_lms = { + "transformer": "deepspeech.models.lm.transformer:TransformerLM", +} + + +def dynamic_import_lm(module): + """Import LM class dynamically. + + Args: + module (str): module_name:class_name or alias in `predefined_lms` + + Returns: + type: LM class + + """ + model_class = dynamic_import(module, predefined_lms) + assert issubclass(model_class, + LMInterface), f"{module} does not implement LMInterface" + return model_class diff --git a/deepspeech/models/st_interface.py b/deepspeech/models/st_interface.py new file mode 100644 index 00000000..05939f9a --- /dev/null +++ b/deepspeech/models/st_interface.py @@ -0,0 +1,75 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ST Interface module.""" +from .asr_interface import ASRInterface +from deepspeech.utils.dynamic_import import dynamic_import + + +class STInterface(ASRInterface): + """ST Interface model implementation. + + NOTE: This class is inherited from ASRInterface to enable joint translation + and recognition when performing multi-task learning with the ASR task. + + """ + + def translate(self, + x, + trans_args, + char_list=None, + rnnlm=None, + ensemble_models=[]): + """Translate x for evaluation.
+ + :param ndarray x: input acoustic feature (B, T, D) or (T, D) + :param namespace trans_args: argument namespace containing options + :param list char_list: list of characters + :param paddle.nn.Layer rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("translate method is not implemented") + + def translate_batch(self, x, trans_args, char_list=None, rnnlm=None): + """Beam search implementation for batch. + + :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc) + :param namespace trans_args: argument namespace containing options + :param list char_list: list of characters + :param paddle.nn.Layer rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("Batch decoding is not supported yet.") + + +predefined_st = { + "transformer": "deepspeech.models.u2_st:U2STModel", +} + + +def dynamic_import_st(module): + """Import ST models dynamically. + + Args: + module (str): module_name:class_name or alias in `predefined_st` + + Returns: + type: ST class + + """ + model_class = dynamic_import(module, predefined_st) + assert issubclass(model_class, + STInterface), f"{module} does not implement STInterface" + return model_class diff --git a/deepspeech/models/u2_st/__init__.py b/deepspeech/models/u2_st/__init__.py new file mode 100644 index 00000000..6b10b083 --- /dev/null +++ b/deepspeech/models/u2_st/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .u2_st import U2STInferModel +from .u2_st import U2STModel diff --git a/deepspeech/models/u2_st.py b/deepspeech/models/u2_st/u2_st.py similarity index 100% rename from deepspeech/models/u2_st.py rename to deepspeech/models/u2_st/u2_st.py diff --git a/deepspeech/modules/embedding.py b/deepspeech/modules/embedding.py index fbbda023..64d594c2 100644 --- a/deepspeech/modules/embedding.py +++ b/deepspeech/modules/embedding.py @@ -22,10 +22,52 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["PositionalEncoding", "RelPositionalEncoding"] +__all__ = [ + "PositionalEncodingInterface", "NoPositionalEncoding", "PositionalEncoding", + "RelPositionalEncoding" +] -class PositionalEncoding(nn.Layer): +class PositionalEncodingInterface: + def forward(self, x: paddle.Tensor, + offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute positional encoding. + Args: + x (paddle.Tensor): Input tensor (batch, time, `*`). + Returns: + paddle.Tensor: Encoded tensor (batch, time, `*`). + paddle.Tensor: Positional embedding tensor (1, time, `*`).
+ """ + raise NotImplementedError("forward method is not implemented") + + def position_encoding(self, offset: int, size: int) -> paddle.Tensor: + """ For getting encoding in a streaming fashion + Args: + offset (int): start offset + size (int): requried size of position encoding + Returns: + paddle.Tensor: Corresponding position encoding + """ + raise NotImplementedError("position_encoding method is not implemented") + + +class NoPositionalEncoding(nn.Layer, PositionalEncodingInterface): + def __init__(self, + d_model: int, + dropout_rate: float, + max_len: int=5000, + reverse: bool=False): + nn.Layer.__init__(self) + + def forward(self, x: paddle.Tensor, + offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: + return x, None + + def position_encoding(self, offset: int, size: int) -> paddle.Tensor: + return None + + +class PositionalEncoding(nn.Layer, PositionalEncodingInterface): def __init__(self, d_model: int, dropout_rate: float, @@ -40,7 +82,7 @@ class PositionalEncoding(nn.Layer): max_len (int, optional): maximum input length. Defaults to 5000. reverse (bool, optional): Not used. Defaults to False. """ - super().__init__() + nn.Layer.__init__(self) self.d_model = d_model self.max_len = max_len self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) @@ -85,7 +127,7 @@ class PositionalEncoding(nn.Layer): offset (int): start offset size (int): requried size of position encoding Returns: - paddle.Tensor: Corresponding encoding + paddle.Tensor: Corresponding position encoding """ assert offset + size < self.max_len return self.dropout(self.pe[:, offset:offset + size]) diff --git a/deepspeech/modules/encoder.py b/deepspeech/modules/encoder.py index 6ffb6465..435b6894 100644 --- a/deepspeech/modules/encoder.py +++ b/deepspeech/modules/encoder.py @@ -24,6 +24,7 @@ from deepspeech.modules.activation import get_activation from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention from deepspeech.modules.conformer_convolution import ConvolutionModule +from deepspeech.modules.embedding import NoPositionalEncoding from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding from deepspeech.modules.encoder_layer import ConformerEncoderLayer @@ -76,7 +77,7 @@ class BaseEncoder(nn.Layer): input_layer (str): input layer type. optional [linear, conv2d, conv2d6, conv2d8] pos_enc_layer_type (str): Encoder positional encoding layer type. - opitonal [abs_pos, scaled_abs_pos, rel_pos] + opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] normalize_before (bool): True: use layer_norm before each sub-block of a layer. False: use layer_norm after each sub-block of a layer. @@ -101,6 +102,8 @@ class BaseEncoder(nn.Layer): pos_enc_class = PositionalEncoding elif pos_enc_layer_type == "rel_pos": pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "no_pos": + pos_enc_class = NoPositionalEncoding else: raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) @@ -370,6 +373,41 @@ class TransformerEncoder(BaseEncoder): concat_after=concat_after) for _ in range(num_blocks) ]) + def forward_one_step( + self, + xs: paddle.Tensor, + masks: paddle.Tensor, + cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Encode input frame. + + Args: + xs (paddle.Tensor): (Prefix) Input tensor. (B, T, D) + masks (paddle.Tensor): Mask tensor. (B, T, T) + cache (List[paddle.Tensor]): List of cache tensors. + + Returns: + paddle.Tensor: Output tensor. 
+ paddle.Tensor: Mask tensor. + List[paddle.Tensor]: List of new cache tensors. + """ + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + + #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor + xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) + #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor + masks = masks.astype(paddle.bool) + + if cache is None: + cache = [None for _ in range(len(self.encoders))] + new_cache = [] + for c, e in zip(cache, self.encoders): + xs, masks, _ = e(xs, masks, output_cache=c) + new_cache.append(xs) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks, new_cache + class ConformerEncoder(BaseEncoder): """Conformer encoder module.""" diff --git a/deepspeech/modules/encoder_layer.py b/deepspeech/modules/encoder_layer.py index 1db556ca..6f49cfc8 100644 --- a/deepspeech/modules/encoder_layer.py +++ b/deepspeech/modules/encoder_layer.py @@ -71,7 +71,7 @@ class TransformerEncoderLayer(nn.Layer): self, x: paddle.Tensor, mask: paddle.Tensor, - pos_emb: paddle.Tensor, + pos_emb: Optional[paddle.Tensor]=None, mask_pad: Optional[paddle.Tensor]=None, output_cache: Optional[paddle.Tensor]=None, cnn_cache: Optional[paddle.Tensor]=None, @@ -82,8 +82,8 @@ class TransformerEncoderLayer(nn.Layer): mask (paddle.Tensor): Mask tensor for the input (#batch, time). pos_emb (paddle.Tensor): just for interface compatibility to ConformerEncoderLayer - mask_pad (paddle.Tensor): does not used in transformer layer, - just for unified api with conformer. + mask_pad (paddle.Tensor): not used here, it's for interface + compatibility to ConformerEncoderLayer output_cache (paddle.Tensor): Cache tensor of the output (#batch, time2, size), time2 < time in x. cnn_cache (paddle.Tensor): not used here, it's for interface diff --git a/deepspeech/modules/subsampling.py b/deepspeech/modules/subsampling.py index 3bed62f3..13e2c8ef 100644 --- a/deepspeech/modules/subsampling.py +++ b/deepspeech/modules/subsampling.py @@ -60,7 +60,8 @@ class LinearNoSubsampling(BaseSubsampling): self.out = nn.Sequential( nn.Linear(idim, odim), nn.LayerNorm(odim, epsilon=1e-12), - nn.Dropout(dropout_rate), ) + nn.Dropout(dropout_rate), + nn.ReLU(), ) self.right_context = 0 self.subsampling_rate = 1 @@ -83,7 +84,12 @@ class LinearNoSubsampling(BaseSubsampling): return x, pos_emb, x_mask -class Conv2dSubsampling4(BaseSubsampling): +class Conv2dSubsampling(BaseSubsampling): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class Conv2dSubsampling4(Conv2dSubsampling): """Convolutional 2D subsampling (to 1/4 length).""" def __init__(self, @@ -134,7 +140,7 @@ class Conv2dSubsampling4(BaseSubsampling): return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] -class Conv2dSubsampling6(BaseSubsampling): +class Conv2dSubsampling6(Conv2dSubsampling): """Convolutional 2D subsampling (to 1/6 length).""" def __init__(self, @@ -187,7 +193,7 @@ class Conv2dSubsampling6(BaseSubsampling): return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-4:3] -class Conv2dSubsampling8(BaseSubsampling): +class Conv2dSubsampling8(Conv2dSubsampling): """Convolutional 2D subsampling (to 1/8 length).""" def __init__(self, diff --git a/deepspeech/training/extensions/plot.py b/deepspeech/training/extensions/plot.py new file mode 100644 index 00000000..6fbb4d4d --- /dev/null +++ b/deepspeech/training/extensions/plot.py @@ -0,0 +1,418 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import os + +import numpy as np + +from . import extension + + +class PlotAttentionReport(extension.Extension): + """Plot attention reporter. + + Args: + att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions): + Function of attention visualization. + data (list[tuple(str, dict[str, list[Any]])]): List json utt key items. + outdir (str): Directory to save figures. + converter (espnet.asr.*_backend.asr.CustomConverter): + Function to convert data. + device (int | torch.device): Device. + reverse (bool): If True, input and output length are reversed. + ikey (str): Key to access input + (for ASR/ST ikey="input", for MT ikey="output".) + iaxis (int): Dimension to access input + (for ASR/ST iaxis=0, for MT iaxis=1.) + okey (str): Key to access output + (for ASR/ST okey="input", MT okay="output".) + oaxis (int): Dimension to access output + (for ASR/ST oaxis=0, for MT oaxis=0.) + subsampling_factor (int): subsampling factor in encoder + + """ + + def __init__( + self, + att_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, ): + self.att_vis_fn = att_vis_fn + self.data = copy.deepcopy(data) + self.data_dict = {k: v for k, v in copy.deepcopy(data)} + # key is utterance ID + self.outdir = outdir + self.converter = converter + self.transform = transform + self.device = device + self.reverse = reverse + self.ikey = ikey + self.iaxis = iaxis + self.okey = okey + self.oaxis = oaxis + self.factor = subsampling_factor + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + def __call__(self, trainer): + """Plot and save image file of att_ws matrix.""" + att_ws, uttid_list = self.get_attention_weights() + if isinstance(att_ws, list): # multi-encoder case + num_encs = len(att_ws) - 1 + # atts + for i in range(num_encs): + for idx, att_w in enumerate(att_ws[i]): + filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % ( + self.outdir, uttid_list[idx], i + 1, ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % ( + self.outdir, uttid_list[idx], i + 1, ) + np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention(att_w, + filename.format(trainer)) + # han + for idx, att_w in enumerate(att_ws[num_encs]): + filename = "%s/%s.ep.{.updater.epoch}.han.png" % ( + self.outdir, uttid_list[idx], ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % ( + self.outdir, uttid_list[idx], ) + np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention( + att_w, filename.format(trainer), han_mode=True) + else: + for idx, att_w in enumerate(att_ws): + filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir, + uttid_list[idx], ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( + self.outdir, uttid_list[idx], ) + 
np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention(att_w, filename.format(trainer)) + + def log_attentions(self, logger, step): + """Add image files of att_ws matrix to the tensorboard.""" + att_ws, uttid_list = self.get_attention_weights() + if isinstance(att_ws, list): # multi-encoder case + num_encs = len(att_ws) - 1 + # atts + for i in range(num_encs): + for idx, att_w in enumerate(att_ws[i]): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_attention_plot(att_w) + logger.add_figure( + "%s_att%d" % (uttid_list[idx], i + 1), + plot.gcf(), + step, ) + # han + for idx, att_w in enumerate(att_ws[num_encs]): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_han_plot(att_w) + logger.add_figure( + "%s_han" % (uttid_list[idx]), + plot.gcf(), + step, ) + else: + for idx, att_w in enumerate(att_ws): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_attention_plot(att_w) + logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step) + + def get_attention_weights(self): + """Return attention weights. + + Returns: + numpy.ndarray: attention weights. float. Its shape would be + differ from backend. + * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2) + other case => (B, Lmax, Tmax). + * chainer-> (B, Lmax, Tmax) + + """ + return_batch, uttid_list = self.transform(self.data, return_uttid=True) + batch = self.converter([return_batch], self.device) + if isinstance(batch, tuple): + att_ws = self.att_vis_fn(*batch) + else: + att_ws = self.att_vis_fn(**batch) + return att_ws, uttid_list + + def trim_attention_weight(self, uttid, att_w): + """Transform attention matrix with regard to self.reverse.""" + if self.reverse: + enc_key, enc_axis = self.okey, self.oaxis + dec_key, dec_axis = self.ikey, self.iaxis + else: + enc_key, enc_axis = self.ikey, self.iaxis + dec_key, dec_axis = self.okey, self.oaxis + dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0]) + enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0]) + if self.factor > 1: + enc_len //= self.factor + if len(att_w.shape) == 3: + att_w = att_w[:, :dec_len, :enc_len] + else: + att_w = att_w[:dec_len, :enc_len] + return att_w + + def draw_attention_plot(self, att_w): + """Plot the att_w matrix. + + Returns: + matplotlib.pyplot: pyplot object with attention matrix image. + + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + plt.clf() + att_w = att_w.astype(np.float32) + if len(att_w.shape) == 3: + for h, aw in enumerate(att_w, 1): + plt.subplot(1, len(att_w), h) + plt.imshow(aw, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + else: + plt.imshow(att_w, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + plt.tight_layout() + return plt + + def draw_han_plot(self, att_w): + """Plot the att_w matrix for hierarchical attention. + + Returns: + matplotlib.pyplot: pyplot object with attention matrix image. 
+ + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + plt.clf() + if len(att_w.shape) == 3: + for h, aw in enumerate(att_w, 1): + legends = [] + plt.subplot(1, len(att_w), h) + for i in range(aw.shape[1]): + plt.plot(aw[:, i]) + legends.append("Att{}".format(i)) + plt.ylim([0, 1.0]) + plt.xlim([0, aw.shape[0]]) + plt.grid(True) + plt.ylabel("Attention Weight") + plt.xlabel("Decoder Index") + plt.legend(legends) + else: + legends = [] + for i in range(att_w.shape[1]): + plt.plot(att_w[:, i]) + legends.append("Att{}".format(i)) + plt.ylim([0, 1.0]) + plt.xlim([0, att_w.shape[0]]) + plt.grid(True) + plt.ylabel("Attention Weight") + plt.xlabel("Decoder Index") + plt.legend(legends) + plt.tight_layout() + return plt + + def _plot_and_save_attention(self, att_w, filename, han_mode=False): + if han_mode: + plt = self.draw_han_plot(att_w) + else: + plt = self.draw_attention_plot(att_w) + plt.savefig(filename) + plt.close() + + +class PlotCTCReport(extension.Extension): + """Plot CTC reporter. + + Args: + ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs): + Function of CTC visualization. + data (list[tuple(str, dict[str, list[Any]])]): List json utt key items. + outdir (str): Directory to save figures. + converter (espnet.asr.*_backend.asr.CustomConverter): + Function to convert data. + device (int | torch.device): Device. + reverse (bool): If True, input and output length are reversed. + ikey (str): Key to access input + (for ASR/ST ikey="input", for MT ikey="output".) + iaxis (int): Dimension to access input + (for ASR/ST iaxis=0, for MT iaxis=1.) + okey (str): Key to access output + (for ASR/ST okey="input", MT okay="output".) + oaxis (int): Dimension to access output + (for ASR/ST oaxis=0, for MT oaxis=0.) 
+ subsampling_factor (int): subsampling factor in encoder + + """ + + def __init__( + self, + ctc_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, ): + self.ctc_vis_fn = ctc_vis_fn + self.data = copy.deepcopy(data) + self.data_dict = {k: v for k, v in copy.deepcopy(data)} + # key is utterance ID + self.outdir = outdir + self.converter = converter + self.transform = transform + self.device = device + self.reverse = reverse + self.ikey = ikey + self.iaxis = iaxis + self.okey = okey + self.oaxis = oaxis + self.factor = subsampling_factor + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + def __call__(self, trainer): + """Plot and save image file of ctc prob.""" + ctc_probs, uttid_list = self.get_ctc_probs() + if isinstance(ctc_probs, list): # multi-encoder case + num_encs = len(ctc_probs) - 1 + for i in range(num_encs): + for idx, ctc_prob in enumerate(ctc_probs[i]): + filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % ( + self.outdir, uttid_list[idx], i + 1, ) + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % ( + self.outdir, uttid_list[idx], i + 1, ) + np.save(np_filename.format(trainer), ctc_prob) + self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) + else: + for idx, ctc_prob in enumerate(ctc_probs): + filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir, + uttid_list[idx], ) + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( + self.outdir, uttid_list[idx], ) + np.save(np_filename.format(trainer), ctc_prob) + self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) + + def log_ctc_probs(self, logger, step): + """Add image files of ctc probs to the tensorboard.""" + ctc_probs, uttid_list = self.get_ctc_probs() + if isinstance(ctc_probs, list): # multi-encoder case + num_encs = len(ctc_probs) - 1 + for i in range(num_encs): + for idx, ctc_prob in enumerate(ctc_probs[i]): + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + plot = self.draw_ctc_plot(ctc_prob) + logger.add_figure( + "%s_ctc%d" % (uttid_list[idx], i + 1), + plot.gcf(), + step, ) + else: + for idx, ctc_prob in enumerate(ctc_probs): + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + plot = self.draw_ctc_plot(ctc_prob) + logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step) + + def get_ctc_probs(self): + """Return CTC probs. + + Returns: + numpy.ndarray: CTC probs. float. Its shape would be + differ from backend. (B, Tmax, vocab). + + """ + return_batch, uttid_list = self.transform(self.data, return_uttid=True) + batch = self.converter([return_batch], self.device) + if isinstance(batch, tuple): + probs = self.ctc_vis_fn(*batch) + else: + probs = self.ctc_vis_fn(**batch) + return probs, uttid_list + + def trim_ctc_prob(self, uttid, prob): + """Trim CTC posteriors accoding to input lengths.""" + enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0]) + if self.factor > 1: + enc_len //= self.factor + prob = prob[:enc_len] + return prob + + def draw_ctc_plot(self, ctc_prob): + """Plot the ctc_prob matrix. + + Returns: + matplotlib.pyplot: pyplot object with CTC prob matrix image. 
+ + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + ctc_prob = ctc_prob.astype(np.float32) + + plt.clf() + topk_ids = np.argsort(ctc_prob, axis=1) + n_frames, vocab = ctc_prob.shape + times_probs = np.arange(n_frames) + + plt.figure(figsize=(20, 8)) + + # NOTE: index 0 is reserved for blank + for idx in set(topk_ids.reshape(-1).tolist()): + if idx == 0: + plt.plot( + times_probs, + ctc_prob[:, 0], + ":", + label="", + color="grey") + else: + plt.plot(times_probs, ctc_prob[:, idx]) + plt.xlabel(u"Input [frame]", fontsize=12) + plt.ylabel("Posteriors", fontsize=12) + plt.xticks(list(range(0, int(n_frames) + 1, 10))) + plt.yticks(list(range(0, 2, 1))) + plt.tight_layout() + return plt + + def _plot_and_save_ctc(self, ctc_prob, filename): + plt = self.draw_ctc_plot(ctc_prob) + plt.savefig(filename) + plt.close() diff --git a/deepspeech/training/triggers/__init__.py b/deepspeech/training/triggers/__init__.py index 1a7c4292..185a92b8 100644 --- a/deepspeech/training/triggers/__init__.py +++ b/deepspeech/training/triggers/__init__.py @@ -11,18 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .interval_trigger import IntervalTrigger - - -def never_fail_trigger(trainer): - return False - - -def get_trigger(trigger): - if trigger is None: - return never_fail_trigger - if callable(trigger): - return trigger - else: - trigger = IntervalTrigger(*trigger) - return trigger diff --git a/deepspeech/training/triggers/compare_value_trigger.py b/deepspeech/training/triggers/compare_value_trigger.py new file mode 100644 index 00000000..efb928e2 --- /dev/null +++ b/deepspeech/training/triggers/compare_value_trigger.py @@ -0,0 +1,61 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..reporter import DictSummary +from .utils import get_trigger + + +class CompareValueTrigger(): + """Trigger invoked when key value getting bigger or lower than before. + + Args: + key (str) : Key of value. + compare_fn ((float, float) -> bool) : Function to compare the values. + trigger (tuple(int, str)) : Trigger that decide the comparison interval. 
+ + """ + + def __init__(self, key, compare_fn, trigger=(1, "epoch")): + self._key = key + self._best_value = None + self._interval_trigger = get_trigger(trigger) + self._init_summary() + self._compare_fn = compare_fn + + def __call__(self, trainer): + """Get value related to the key and compare with current value.""" + observation = trainer.observation + summary = self._summary + key = self._key + if key in observation: + summary.add({key: observation[key]}) + + if not self._interval_trigger(trainer): + return False + + stats = summary.compute_mean() + value = float(stats[key]) # copy to CPU + self._init_summary() + + if self._best_value is None: + # initialize best value + self._best_value = value + return False + elif self._compare_fn(self._best_value, value): + return True + else: + self._best_value = value + return False + + def _init_summary(self): + self._summary = DictSummary() diff --git a/deepspeech/training/triggers/time_trigger.py b/deepspeech/training/triggers/time_trigger.py index ea8fe562..e31179a9 100644 --- a/deepspeech/training/triggers/time_trigger.py +++ b/deepspeech/training/triggers/time_trigger.py @@ -30,3 +30,12 @@ class TimeTrigger(): return True else: return False + + def state_dict(self): + state_dict = { + "next_time": self._next_time, + } + return state_dict + + def set_state_dict(self, state_dict): + self._next_time = state_dict['next_time'] diff --git a/deepspeech/training/triggers/utils.py b/deepspeech/training/triggers/utils.py new file mode 100644 index 00000000..1a7c4292 --- /dev/null +++ b/deepspeech/training/triggers/utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .interval_trigger import IntervalTrigger + + +def never_fail_trigger(trainer): + return False + + +def get_trigger(trigger): + if trigger is None: + return never_fail_trigger + if callable(trigger): + return trigger + else: + trigger = IntervalTrigger(*trigger) + return trigger diff --git a/deepspeech/transform/__init__.py b/deepspeech/transform/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/deepspeech/transform/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/deepspeech/transform/add_deltas.py b/deepspeech/transform/add_deltas.py new file mode 100644 index 00000000..68f44d41 --- /dev/null +++ b/deepspeech/transform/add_deltas.py @@ -0,0 +1,41 @@ +import numpy as np + + +def delta(feat, window): + assert window > 0 + delta_feat = np.zeros_like(feat) + for i in range(1, window + 1): + delta_feat[:-i] += i * feat[i:] + delta_feat[i:] += -i * feat[:-i] + delta_feat[-i:] += i * feat[-1] + delta_feat[:i] += -i * feat[0] + delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1)) + return delta_feat + + +def add_deltas(x, window=2, order=2): + """ + Args: + x (np.ndarray): speech feat, (T, D). + + Return: + np.ndarray: (T, (1+order)*D) + """ + feats = [x] + for _ in range(order): + feats.append(delta(feats[-1], window)) + return np.concatenate(feats, axis=1) + + +class AddDeltas(): + def __init__(self, window=2, order=2): + self.window = window + self.order = order + + def __repr__(self): + return "{name}(window={window}, order={order})".format( + name=self.__class__.__name__, window=self.window, order=self.order + ) + + def __call__(self, x): + return add_deltas(x, window=self.window, order=self.order) diff --git a/deepspeech/transform/channel_selector.py b/deepspeech/transform/channel_selector.py new file mode 100644 index 00000000..1ac9e350 --- /dev/null +++ b/deepspeech/transform/channel_selector.py @@ -0,0 +1,45 @@ +import numpy + + +class ChannelSelector(): + """Select 1ch from multi-channel signal""" + + def __init__(self, train_channel="random", eval_channel=0, axis=1): + self.train_channel = train_channel + self.eval_channel = eval_channel + self.axis = axis + + def __repr__(self): + return ( + "{name}(train_channel={train_channel}, " + "eval_channel={eval_channel}, axis={axis})".format( + name=self.__class__.__name__, + train_channel=self.train_channel, + eval_channel=self.eval_channel, + axis=self.axis, + ) + ) + + def __call__(self, x, train=True): + # Assuming x: [Time, Channel] by default + + if x.ndim <= self.axis: + # If the dimension is insufficient, then unsqueeze + # (e.g [Time] -> [Time, 1]) + ind = tuple( + slice(None) if i < x.ndim else None for i in range(self.axis + 1) + ) + x = x[ind] + + if train: + channel = self.train_channel + else: + channel = self.eval_channel + + if channel == "random": + ch = numpy.random.randint(0, x.shape[self.axis]) + else: + ch = channel + + ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim)) + return x[ind] diff --git a/deepspeech/transform/cmvn.py b/deepspeech/transform/cmvn.py new file mode 100644 index 00000000..5d318590 --- /dev/null +++ b/deepspeech/transform/cmvn.py @@ -0,0 +1,158 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io + +import h5py +import kaldiio +import numpy as np + + +class CMVN(): + "Apply global/speaker CMVN or inverse CMVN."
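+    # Note: `stats` may be a dict of precomputed statistics, a Kaldi .mat/.npy file
+    # for global CMVN, or an ark/hdf5 archive for per-speaker CMVN; set
+    # `reverse=True` to apply the inverse transform.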
+ + def __init__( + self, + stats, + norm_means=True, + norm_vars=False, + filetype="mat", + utt2spk=None, + spk2utt=None, + reverse=False, + std_floor=1.0e-20, ): + self.stats_file = stats + self.norm_means = norm_means + self.norm_vars = norm_vars + self.reverse = reverse + + if isinstance(stats, dict): + stats_dict = dict(stats) + else: + # Use for global CMVN + if filetype == "mat": + stats_dict = {None: kaldiio.load_mat(stats)} + # Use for global CMVN + elif filetype == "npy": + stats_dict = {None: np.load(stats)} + # Use for speaker CMVN + elif filetype == "ark": + self.accept_uttid = True + stats_dict = dict(kaldiio.load_ark(stats)) + # Use for speaker CMVN + elif filetype == "hdf5": + self.accept_uttid = True + stats_dict = h5py.File(stats) + else: + raise ValueError("Not supporting filetype={}".format(filetype)) + + if utt2spk is not None: + self.utt2spk = {} + with io.open(utt2spk, "r", encoding="utf-8") as f: + for line in f: + utt, spk = line.rstrip().split(None, 1) + self.utt2spk[utt] = spk + elif spk2utt is not None: + self.utt2spk = {} + with io.open(spk2utt, "r", encoding="utf-8") as f: + for line in f: + spk, utts = line.rstrip().split(None, 1) + for utt in utts.split(): + self.utt2spk[utt] = spk + else: + self.utt2spk = None + + # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1), + # and the first vector contains the sum of feats and the second is + # the sum of squares. The last value of the first, i.e. stats[0,-1], + # is the number of samples for this statistics. + self.bias = {} + self.scale = {} + for spk, stats in stats_dict.items(): + assert len(stats) == 2, stats.shape + + count = stats[0, -1] + + # If the feature has two or more dimensions + if not (np.isscalar(count) or isinstance(count, (int, float))): + # The first is only used + count = count.flatten()[0] + + mean = stats[0, :-1] / count + # V(x) = E(x^2) - (E(x))^2 + var = stats[1, :-1] / count - mean * mean + std = np.maximum(np.sqrt(var), std_floor) + self.bias[spk] = -mean + self.scale[spk] = 1 / std + + def __repr__(self): + return ("{name}(stats_file={stats_file}, " + "norm_means={norm_means}, norm_vars={norm_vars}, " + "reverse={reverse})".format( + name=self.__class__.__name__, + stats_file=self.stats_file, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + reverse=self.reverse, )) + + def __call__(self, x, uttid=None): + if self.utt2spk is not None: + spk = self.utt2spk[uttid] + else: + spk = uttid + + if not self.reverse: + # apply cmvn + if self.norm_means: + x = np.add(x, self.bias[spk]) + if self.norm_vars: + x = np.multiply(x, self.scale[spk]) + + else: + # apply reverse cmvn + if self.norm_vars: + x = np.divide(x, self.scale[spk]) + if self.norm_means: + x = np.subtract(x, self.bias[spk]) + + return x + + +class UtteranceCMVN(): + "Apply Utterance CMVN" + + def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20): + self.norm_means = norm_means + self.norm_vars = norm_vars + self.std_floor = std_floor + + def __repr__(self): + return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format( + name=self.__class__.__name__, + norm_means=self.norm_means, + norm_vars=self.norm_vars, ) + + def __call__(self, x, uttid=None): + # x: [Time, Dim] + square_sums = (x**2).sum(axis=0) + mean = x.mean(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + + if self.norm_vars: + var = square_sums / x.shape[0] - mean**2 + std = np.maximum(np.sqrt(var), self.std_floor) + x = np.divide(x, std) + + return x diff --git a/deepspeech/transform/functional.py 
b/deepspeech/transform/functional.py new file mode 100644 index 00000000..5eec6cc1 --- /dev/null +++ b/deepspeech/transform/functional.py @@ -0,0 +1,71 @@ +import inspect + +from deepspeech.transform.transform_interface import TransformInterface +from deepspeech.utils.check_kwargs import check_kwargs + + +class FuncTrans(TransformInterface): + """Functional Transformation + + WARNING: + Builtin or C/C++ functions may not work properly + because this class heavily depends on the `inspect` module. + + Usage: + + >>> def foo_bar(x, a=1, b=2): + ... '''Foo bar + ... :param x: input + ... :param int a: default 1 + ... :param int b: default 2 + ... ''' + ... return x + a - b + + + >>> class FooBar(FuncTrans): + ... _func = foo_bar + ... __doc__ = foo_bar.__doc__ + """ + + _func = None + + def __init__(self, **kwargs): + self.kwargs = kwargs + check_kwargs(self.func, kwargs) + + def __call__(self, x): + return self.func(x, **self.kwargs) + + @classmethod + def add_arguments(cls, parser): + fname = cls._func.__name__.replace("_", "-") + group = parser.add_argument_group(fname + " transformation setting") + for k, v in cls.default_params().items(): + # TODO(karita): get help and choices from docstring? + attr = k.replace("_", "-") + group.add_argument(f"--{fname}-{attr}", default=v, type=type(v)) + return parser + + @property + def func(self): + return type(self)._func + + @classmethod + def default_params(cls): + try: + d = dict(inspect.signature(cls._func).parameters) + except ValueError: + d = dict() + return { + k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty + } + + def __repr__(self): + params = self.default_params() + params.update(**self.kwargs) + ret = self.__class__.__name__ + "(" + if len(params) == 0: + return ret + ")" + for k, v in params.items(): + ret += "{}={}, ".format(k, v) + return ret[:-2] + ")" diff --git a/deepspeech/transform/perturb.py b/deepspeech/transform/perturb.py new file mode 100644 index 00000000..05766cad --- /dev/null +++ b/deepspeech/transform/perturb.py @@ -0,0 +1,343 @@ +import librosa +import numpy +import scipy +import soundfile + +from deepspeech.io.reader import SoundHDF5File + +class SpeedPerturbation(): + """SpeedPerturbation + + The speed perturbation in kaldi uses sox-speed instead of sox-tempo, + and sox-speed just to resample the input, + i.e pitch and tempo are changed both. + + "Why use speed option instead of tempo -s in SoX for speed perturbation" + https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 + + Warning: + This function is very slow because of resampling. + I recommmend to apply speed-perturb outside the training using sox. 
+ + """ + + def __init__( + self, + lower=0.9, + upper=1.1, + utt2ratio=None, + keep_length=True, + res_type="kaiser_best", + seed=None, + ): + self.res_type = res_type + self.keep_length = keep_length + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + self.utt2ratio = {} + # Use the scheduled ratio for each utterances + self.utt2ratio_file = utt2ratio + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + self.utt2ratio = None + # The ratio is given on runtime randomly + self.lower = lower + self.upper = upper + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format( + self.__class__.__name__, + self.lower, + self.upper, + self.keep_length, + self.res_type, + ) + else: + return "{}({}, res_type={})".format( + self.__class__.__name__, self.utt2ratio_file, self.res_type + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + # Note1: resample requires the sampling-rate of input and output, + # but actually only the ratio is used. + y = librosa.resample(x, ratio, 1, res_type=self.res_type) + + if self.keep_length: + diff = abs(len(x) - len(y)) + if len(y) > len(x): + # Truncate noise + y = y[diff // 2 : -((diff + 1) // 2)] + elif len(y) < len(x): + # Assume the time-axis is the first: (Time, Channel) + pad_width = [(diff // 2, (diff + 1) // 2)] + [ + (0, 0) for _ in range(y.ndim - 1) + ] + y = numpy.pad( + y, pad_width=pad_width, constant_values=0, mode="constant" + ) + return y + + +class BandpassPerturbation(): + """BandpassPerturbation + + Randomly dropout along the frequency axis. + + The original idea comes from the following: + "randomly-selected frequency band was cut off under the constraint of + leaving at least 1,000 Hz band within the range of less than 4,000Hz." 
+ (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for + everyday home environments using multiple microphone arrays; + http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf) + + """ + + def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)): + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + # x_stft: (Time, Channel, Freq) + self.axes = axes + + def __repr__(self): + return "{}(lower={}, upper={})".format( + self.__class__.__name__, self.lower, self.upper + ) + + def __call__(self, x_stft, uttid=None, train=True): + if not train: + return x_stft + + if x_stft.ndim == 1: + raise RuntimeError( + "Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)" + ) + + ratio = self.state.uniform(self.lower, self.upper) + axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] + shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)] + + mask = self.state.randn(*shape) > ratio + x_stft *= mask + return x_stft + + +class VolumePerturbation(): + def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None): + self.dbunit = dbunit + self.utt2ratio_file = utt2ratio + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit + ) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + if self.dbunit: + ratio = 10 ** (ratio / 20) + return x * ratio + + +class NoiseInjection(): + """Add isotropic noise""" + + def __init__( + self, + utt2noise=None, + lower=-20, + upper=-5, + utt2ratio=None, + filetype="list", + dbunit=True, + seed=None, + ): + self.utt2noise_file = utt2noise + self.utt2ratio_file = utt2ratio + self.filetype = filetype + self.dbunit = dbunit + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + with open(utt2noise, "r") as f: + for line in f: + utt, snr = line.rstrip().split(None, 1) + snr = float(snr) + self.utt2ratio[utt] = snr + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + if utt2noise is not None: + self.utt2noise = {} + if filetype == "list": + with open(utt2noise, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + # Load all files in memory + self.utt2noise[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2noise = SoundHDF5File(utt2noise, "r") + else: + raise ValueError(filetype) + else: + self.utt2noise = None + + if utt2noise is not None and utt2ratio is not None: + if set(self.utt2ratio) != set(self.utt2noise): + raise 
RuntimeError( + "The uttids mismatch between {} and {}".format(utt2ratio, utt2noise) + ) + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit + ) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + x = x.astype(numpy.float32) + + # 1. Get ratio of noise to signal in sound pressure level + if uttid is not None and self.utt2ratio is not None: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + if self.dbunit: + ratio = 10 ** (ratio / 20) + scale = ratio * numpy.sqrt((x ** 2).mean()) + + # 2. Get noise + if self.utt2noise is not None: + # Get noise from the external source + if uttid is not None: + noise, rate = self.utt2noise[uttid] + else: + # Randomly select the noise source + noise = self.state.choice(list(self.utt2noise.values())) + # Normalize the level + noise /= numpy.sqrt((noise ** 2).mean()) + + # Adjust the noise length + diff = abs(len(x) - len(noise)) + offset = self.state.randint(0, diff) + if len(noise) > len(x): + # Truncate noise + noise = noise[offset : -(diff - offset)] + else: + noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap") + + else: + # Generate white noise + noise = self.state.normal(0, 1, x.shape) + + # 3. Add noise to signal + return x + noise * scale + + +class RIRConvolve(): + def __init__(self, utt2rir, filetype="list"): + self.utt2rir_file = utt2rir + self.filetype = filetype + + self.utt2rir = {} + if filetype == "list": + with open(utt2rir, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + self.utt2rir[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2rir = SoundHDF5File(utt2rir, "r") + else: + raise NotImplementedError(filetype) + + def __repr__(self): + return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if x.ndim != 1: + # Must be single channel + raise RuntimeError( + "Input x must be one dimensional array, but got {}".format(x.shape) + ) + + rir, rate = self.utt2rir[uttid] + if rir.ndim == 2: + # FIXME(kamo): Use chainer.convolution_1d? 
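+            # NOTE: each row of rir is treated as one channel of the impulse
+            # response: the 1-D input is convolved with every row and the
+            # results are stacked along the last axis, giving the
+            # (Time, Channel) output described below.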
+ # return [Time, Channel] + return numpy.stack( + [scipy.convolve(x, r, mode="same") for r in rir], axis=-1 + ) + else: + return scipy.convolve(x, rir, mode="same") diff --git a/deepspeech/transform/spec_augment.py b/deepspeech/transform/spec_augment.py new file mode 100644 index 00000000..feb712df --- /dev/null +++ b/deepspeech/transform/spec_augment.py @@ -0,0 +1,202 @@ +"""Spec Augment module for preprocessing i.e., data augmentation""" + +import random + +import numpy +from PIL import Image +from PIL.Image import BICUBIC + +from deepspeech.transform.functional import FuncTrans + + +def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): + """time warp for spec augment + + move random center frame by the random width ~ uniform(-window, window) + :param numpy.ndarray x: spectrogram (time, freq) + :param int max_time_warp: maximum time frames to warp + :param bool inplace: overwrite x with the result + :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" + (slow, differentiable) + :returns numpy.ndarray: time warped spectrogram (time, freq) + """ + window = max_time_warp + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC) + if inplace: + x[:warped] = left + x[warped:] = right + return x + return numpy.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + import paddle + + from espnet.utils import spec_augment + + # TODO(karita): make this differentiable again + return spec_augment.time_warp(paddle.to_tensor(x), window).numpy() + else: + raise NotImplementedError( + "unknown resize mode: " + + mode + + ", choose one from (PIL, sparse_image_warp)." 
+ ) + + +class TimeWarp(FuncTrans): + _func = time_warp + __doc__ = time_warp.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray x: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = x + else: + cloned = x.copy() + + num_mel_channels = cloned.shape[1] + fs = numpy.random.randint(0, F, size=(n_mask, 2)) + + for f, mask_end in fs: + f_zero = random.randrange(0, num_mel_channels - f) + mask_end += f_zero + + # avoids randrange error if values are equal and range is empty + if f_zero == f_zero + f: + continue + + if replace_with_zero: + cloned[:, f_zero:mask_end] = 0 + else: + cloned[:, f_zero:mask_end] = cloned.mean() + return cloned + + +class FreqMask(FuncTrans): + _func = freq_mask + __doc__ = freq_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray spec: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = spec + else: + cloned = spec.copy() + len_spectro = cloned.shape[0] + ts = numpy.random.randint(0, T, size=(n_mask, 2)) + for t, mask_end in ts: + # avoid randint range error + if len_spectro - t <= 0: + continue + t_zero = random.randrange(0, len_spectro - t) + + # avoids randrange error if values are equal and range is empty + if t_zero == t_zero + t: + continue + + mask_end += t_zero + if replace_with_zero: + cloned[t_zero:mask_end] = 0 + else: + cloned[t_zero:mask_end] = cloned.mean() + return cloned + + +class TimeMask(FuncTrans): + _func = time_mask + __doc__ = time_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def spec_augment( + x, + resize_mode="PIL", + max_time_warp=80, + max_freq_width=27, + n_freq_mask=2, + max_time_width=100, + n_time_mask=2, + inplace=True, + replace_with_zero=True, +): + """spec agument + + apply random time warping and time/freq masking + default setting is based on LD (Librispeech double) in Table 2 + https://arxiv.org/pdf/1904.08779.pdf + + :param numpy.ndarray x: (time, freq) + :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" + (slow, differentiable) + :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) + :param int freq_mask_width: maximum width of the random freq mask (F) + :param int n_freq_mask: the number of the random freq mask (m_F) + :param int time_mask_width: maximum width of the random time mask (T) + :param int n_time_mask: the number of the random time mask (m_T) + :param bool inplace: overwrite intermediate array + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + assert isinstance(x, numpy.ndarray) + assert x.ndim == 2 + x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) + x = freq_mask( + x, + max_freq_width, + n_freq_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, + ) + x = time_mask( + x, + max_time_width, + n_time_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, + ) + return x + + +class 
SpecAugment(FuncTrans): + _func = spec_augment + __doc__ = spec_augment.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) diff --git a/deepspeech/transform/spectrogram.py b/deepspeech/transform/spectrogram.py new file mode 100644 index 00000000..68d47627 --- /dev/null +++ b/deepspeech/transform/spectrogram.py @@ -0,0 +1,307 @@ +import librosa +import numpy as np + + +def stft( + x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect" +): + # x: [Time, Channel] + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? + # x: [Time, Channel, Freq] + x = np.stack( + [ + librosa.stft( + x[:, ch], + n_fft=n_fft, + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + ).T + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel, Freq] -> [Time, Freq] + x = x[:, 0] + return x + + +def istft(x, n_shift, win_length=None, window="hann", center=True): + # x: [Time, Channel, Freq] + if x.ndim == 2: + single_channel = True + # x: [Time, Freq] -> [Time, Channel, Freq] + x = x[:, None, :] + else: + single_channel = False + + # x: [Time, Channel] + x = np.stack( + [ + librosa.istft( + x[:, ch].T, # [Time, Freq] -> [Freq, Time] + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + ) + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel] -> [Time] + x = x[:, 0] + return x + + +def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + # x_stft: (Time, Channel, Freq) or (Time, Freq) + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + + # spc: (Time, Channel, Freq) or (Time, Freq) + spc = np.abs(x_stft) + # mel_basis: (Mel_freq, Freq) + mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax) + # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) + lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + return lmspc + + +def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): + # x: (Time, Channel) -> spc: (Time, Channel, Freq) + spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) + return spc + + +def logmelspectrogram( + x, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + pad_mode="reflect", +): + # stft: (Time, Channel, Freq) or (Time, Freq) + x_stft = stft( + x, + n_fft=n_fft, + n_shift=n_shift, + win_length=win_length, + window=window, + pad_mode=pad_mode, + ) + + return stft2logmelspectrogram( + x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps + ) + + +class Spectrogram(): + def __init__(self, n_fft, n_shift, win_length=None, window="hann"): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + + def __repr__(self): + return ( + "{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + ) + + def __call__(self, x): + return spectrogram( + x, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + + +class LogMelSpectrogram(): + def __init__( + self, + fs, + n_mels, + n_fft, + 
n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + ): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "n_shift={n_shift}, win_length={win_length}, window={window}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, + ) + ) + + def __call__(self, x): + return logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + + +class Stft2LogMelSpectrogram(): + def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, + ) + ) + + def __call__(self, x): + return stft2logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, + ) + + +class Stft(): + def __init__( + self, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect", + ): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + self.pad_mode = pad_mode + + def __repr__(self): + return ( + "{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center}, pad_mode={pad_mode})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, + ) + ) + + def __call__(self, x): + return stft( + x, + self.n_fft, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, + ) + + +class IStft(): + def __init__(self, n_shift, win_length=None, window="hann", center=True): + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + + def __repr__(self): + return ( + "{name}(n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center})".format( + name=self.__class__.__name__, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + ) + ) + + def __call__(self, x): + return istft( + x, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + ) diff --git a/deepspeech/transform/transform_interface.py b/deepspeech/transform/transform_interface.py new file mode 100644 index 00000000..8a6aba45 --- /dev/null +++ b/deepspeech/transform/transform_interface.py @@ -0,0 +1,20 @@ +# TODO(karita): add this to all the transform impl. 
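+# A transform is simply a callable: it implements __call__ and may expose
+# add_arguments for CLI registration.  A minimal sketch of a custom transform
+# built on this interface (illustrative only; it is not registered in the
+# import_alias table of transformation.py):
+#
+#     class Scale(TransformInterface):
+#         """Multiply features by a constant factor."""
+#
+#         def __init__(self, factor=2.0):
+#             self.factor = factor
+#
+#         def __call__(self, x):
+#             return x * self.factor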
+class TransformInterface: + """Transform Interface""" + + def __call__(self, x): + raise NotImplementedError("__call__ method is not implemented") + + @classmethod + def add_arguments(cls, parser): + return parser + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class Identity(TransformInterface): + """Identity Function""" + + def __call__(self, x): + return x diff --git a/deepspeech/transform/transformation.py b/deepspeech/transform/transformation.py new file mode 100644 index 00000000..0f8c39bb --- /dev/null +++ b/deepspeech/transform/transformation.py @@ -0,0 +1,149 @@ +"""Transformation module.""" +from collections.abc import Sequence +from collections import OrderedDict +import copy +from inspect import signature +import io +import logging + +import yaml + +from deepspeech.utils.dynamic_import import dynamic_import + + +# TODO(karita): inherit TransformInterface +# TODO(karita): register cmd arguments in asr_train.py +import_alias = dict( + identity="deepspeech.transform.transform_interface:Identity", + time_warp="deepspeech.transform.spec_augment:TimeWarp", + time_mask="deepspeech.transform.spec_augment:TimeMask", + freq_mask="deepspeech.transform.spec_augment:FreqMask", + spec_augment="deepspeech.transform.spec_augment:SpecAugment", + speed_perturbation="deepspeech.transform.perturb:SpeedPerturbation", + volume_perturbation="deepspeech.transform.perturb:VolumePerturbation", + noise_injection="deepspeech.transform.perturb:NoiseInjection", + bandpass_perturbation="deepspeech.transform.perturb:BandpassPerturbation", + rir_convolve="deepspeech.transform.perturb:RIRConvolve", + delta="deepspeech.transform.add_deltas:AddDeltas", + cmvn="deepspeech.transform.cmvn:CMVN", + utterance_cmvn="deepspeech.transform.cmvn:UtteranceCMVN", + fbank="deepspeech.transform.spectrogram:LogMelSpectrogram", + spectrogram="deepspeech.transform.spectrogram:Spectrogram", + stft="deepspeech.transform.spectrogram:Stft", + istft="deepspeech.transform.spectrogram:IStft", + stft2fbank="deepspeech.transform.spectrogram:Stft2LogMelSpectrogram", + wpe="deepspeech.transform.wpe:WPE", + channel_selector="deepspeech.transform.channel_selector:ChannelSelector", +) + + +class Transformation(): + """Apply some functions to the mini-batch + + Examples: + >>> kwargs = {"process": [{"type": "fbank", + ... "n_mels": 80, + ... "fs": 16000}, + ... {"type": "cmvn", + ... "stats": "data/train/cmvn.ark", + ... "norm_vars": True}, + ... {"type": "delta", "window": 2, "order": 2}]} + >>> transform = Transformation(kwargs) + >>> bs = 10 + >>> xs = [np.random.randn(100, 80).astype(np.float32) + ... for _ in range(bs)] + >>> xs = transform(xs) + """ + + def __init__(self, conffile=None): + if conffile is not None: + if isinstance(conffile, dict): + self.conf = copy.deepcopy(conffile) + else: + with io.open(conffile, encoding="utf-8") as f: + self.conf = yaml.safe_load(f) + assert isinstance(self.conf, dict), type(self.conf) + else: + self.conf = {"mode": "sequential", "process": []} + + self.functions = OrderedDict() + if self.conf.get("mode", "sequential") == "sequential": + for idx, process in enumerate(self.conf["process"]): + assert isinstance(process, dict), type(process) + opts = dict(process) + process_type = opts.pop("type") + class_obj = dynamic_import(process_type, import_alias) + # TODO(karita): assert issubclass(class_obj, TransformInterface) + try: + self.functions[idx] = class_obj(**opts) + except TypeError: + try: + signa = signature(class_obj) + except ValueError: + # Some function, e.g. 
built-in function, are failed + pass + else: + logging.error( + "Expected signature: {}({})".format( + class_obj.__name__, signa + ) + ) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"]) + ) + + def __repr__(self): + rep = "\n" + "\n".join( + " {}: {}".format(k, v) for k, v in self.functions.items() + ) + return "{}({})".format(self.__class__.__name__, rep) + + def __call__(self, xs, uttid_list=None, **kwargs): + """Return new mini-batch + + :param Union[Sequence[np.ndarray], np.ndarray] xs: + :param Union[Sequence[str], str] uttid_list: + :return: batch: + :rtype: List[np.ndarray] + """ + if not isinstance(xs, Sequence): + is_batch = False + xs = [xs] + else: + is_batch = True + + if isinstance(uttid_list, str): + uttid_list = [uttid_list for _ in range(len(xs))] + + if self.conf.get("mode", "sequential") == "sequential": + for idx in range(len(self.conf["process"])): + func = self.functions[idx] + # TODO(karita): use TrainingTrans and UttTrans to check __call__ args + # Derive only the args which the func has + try: + param = signature(func).parameters + except ValueError: + # Some function, e.g. built-in function, are failed + param = {} + _kwargs = {k: v for k, v in kwargs.items() if k in param} + try: + if uttid_list is not None and "uttid" in param: + xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)] + else: + xs = [func(x, **_kwargs) for x in xs] + except Exception: + logging.fatal( + "Catch a exception from {}th func: {}".format(idx, func) + ) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"]) + ) + + if is_batch: + return xs + else: + return xs[0] diff --git a/deepspeech/transform/wpe.py b/deepspeech/transform/wpe.py new file mode 100644 index 00000000..8aed97e6 --- /dev/null +++ b/deepspeech/transform/wpe.py @@ -0,0 +1,45 @@ +from nara_wpe.wpe import wpe + + +class WPE(object): + def __init__( + self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full" + ): + self.taps = taps + self.delay = delay + self.iterations = iterations + self.psd_context = psd_context + self.statistics_mode = statistics_mode + + def __repr__(self): + return ( + "{name}(taps={taps}, delay={delay}" + "iterations={iterations}, psd_context={psd_context}, " + "statistics_mode={statistics_mode})".format( + name=self.__class__.__name__, + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, + ) + ) + + def __call__(self, xs): + """Return enhanced + + :param np.ndarray xs: (Time, Channel, Frequency) + :return: enhanced_xs + :rtype: np.ndarray + + """ + # nara_wpe.wpe: (F, C, T) + xs = wpe( + xs.transpose((2, 1, 0)), + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, + ) + return xs.transpose(2, 1, 0) diff --git a/deepspeech/utils/asr_utils.py b/deepspeech/utils/asr_utils.py new file mode 100644 index 00000000..6f86e56f --- /dev/null +++ b/deepspeech/utils/asr_utils.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +import numpy as np + +__all__ = ["label_smoothing_dist"] + + +# TODO(takaaki-hori): add different smoothing methods +def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): + """Obtain label distribution for loss smoothing. + + :param odim: + :param lsm_type: + :param blank: + :param transcript: + :return: + """ + if transcript is not None: + with open(transcript, "rb") as f: + trans_json = json.load(f)["utts"] + + if lsm_type == "unigram": + assert transcript is not None, ( + "transcript is required for %s label smoothing" % lsm_type) + labelcount = np.zeros(odim) + for k, v in trans_json.items(): + ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) + # to avoid an error when there is no text in an uttrance + if len(ids) > 0: + labelcount[ids] += 1 + labelcount[odim - 1] = len(transcript) # count + labelcount[labelcount == 0] = 1 # flooring + labelcount[blank] = 0 # remove counts for blank + labeldist = labelcount.astype(np.float32) / np.sum(labelcount) + else: + logging.error("Error: unexpected label smoothing type: %s" % lsm_type) + sys.exit() + + return labeldist diff --git a/deepspeech/utils/bleu_score.py b/deepspeech/utils/bleu_score.py index 09646133..ea32fcf9 100644 --- a/deepspeech/utils/bleu_score.py +++ b/deepspeech/utils/bleu_score.py @@ -14,17 +14,17 @@ """This module provides functions to calculate bleu score in different level. e.g. wer for word-level, cer for char-level. """ +import nltk +import numpy as np import sacrebleu -__all__ = ['bleu', 'char_bleu'] +__all__ = ['bleu', 'char_bleu', "ErrorCalculator"] def bleu(hypothesis, reference): """Calculate BLEU. BLEU compares reference text and hypothesis text in word-level using scarebleu. - - :param reference: The reference sentences. :type reference: list[list[str]] :param hypothesis: The hypothesis sentence. @@ -39,8 +39,6 @@ def char_bleu(hypothesis, reference): """Calculate BLEU. BLEU compares reference text and hypothesis text in char-level using scarebleu. - - :param reference: The reference sentences. :type reference: list[list[str]] :param hypothesis: The hypothesis sentence. @@ -52,3 +50,70 @@ def char_bleu(hypothesis, reference): for ref in reference] return sacrebleu.corpus_bleu(hypothesis, reference) + + +class ErrorCalculator(): + """Calculate BLEU for ST and MT models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: vocabulary list + :param sym_space: space symbol + :param sym_pad: pad symbol + :param report_bleu: report BLUE score if True + """ + + def __init__(self, char_list, sym_space, sym_pad, report_bleu=False): + """Construct an ErrorCalculator object.""" + super().__init__() + self.char_list = char_list + self.space = sym_space + self.pad = sym_pad + self.report_bleu = report_bleu + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad): + """Calculate corpus-level BLEU score. 
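+
+        (A rough usage sketch: the space/pad symbols are illustrative; ys_hat
+        and ys_pad are integer id arrays of shape (batch, seqlen), with -1
+        used as the padding id in ys_pad.)
+
+        >>> calc = ErrorCalculator(char_list, "<space>", "<pad>", report_bleu=True)
+        >>> bleu = calc(ys_hat, ys_pad)   # corpus-level BLEU * 100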
+ + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :return: corpus-level BLEU score in a mini-batch + :rtype float + """ + bleu = None + if not self.report_bleu: + return bleu + + bleu = self.calculate_corpus_bleu(ys_hat, ys_pad) + return bleu + + def calculate_corpus_bleu(self, ys_hat, ys_pad): + """Calculate corpus-level BLEU score in a mini-batch. + + :param torch.Tensor seqs_hat: prediction (batch, seqlen) + :param torch.Tensor seqs_true: reference (batch, seqlen) + :return: corpus-level BLEU score + :rtype float + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) + # NOTE: padding index (-1) in y_true is used to pad y_hat + # because y_hats is not padded with -1 + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.pad, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], + seqs_hat) + return bleu * 100 diff --git a/deepspeech/utils/check_kwargs.py b/deepspeech/utils/check_kwargs.py new file mode 100644 index 00000000..593bfa24 --- /dev/null +++ b/deepspeech/utils/check_kwargs.py @@ -0,0 +1,20 @@ +import inspect + + +def check_kwargs(func, kwargs, name=None): + """check kwargs are valid for func + + If kwargs are invalid, raise TypeError as same as python default + :param function func: function to be validated + :param dict kwargs: keyword arguments for func + :param str name: name used in TypeError (default is func name) + """ + try: + params = inspect.signature(func).parameters + except ValueError: + return + if name is None: + name = func.__name__ + for k in kwargs.keys(): + if k not in params: + raise TypeError(f"{name}() got an unexpected keyword argument '{k}'") diff --git a/deepspeech/utils/cli_readers.py b/deepspeech/utils/cli_readers.py new file mode 100644 index 00000000..72aa2bdb --- /dev/null +++ b/deepspeech/utils/cli_readers.py @@ -0,0 +1,241 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import io +import logging +import sys + +import h5py +import kaldiio +import soundfile + +from deepspeech.io.reader import SoundHDF5File + + +def file_reader_helper( + rspecifier: str, + filetype: str="mat", + return_shape: bool=False, + segments: str=None, ): + """Read uttid and array in kaldi style + + This function might be a bit confusing as "ark" is used + for HDF5 to imitate "kaldi-rspecifier". + + Args: + rspecifier: Give as "ark:feats.ark" or "scp:feats.scp" + filetype: "mat" is kaldi-martix, "hdf5": HDF5 + return_shape: Return the shape of the matrix, + instead of the matrix. 
This can reduce IO cost for HDF5. + segments (str): The file format is + " \n" + "e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n" + Returns: + Generator[Tuple[str, np.ndarray], None, None]: + + Examples: + Read from kaldi-matrix ark file: + + >>> for u, array in file_reader_helper('ark:feats.ark', 'mat'): + ... array + + Read from HDF5 file: + + >>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'): + ... array + + """ + if filetype == "mat": + return KaldiReader( + rspecifier, return_shape=return_shape, segments=segments) + elif filetype == "hdf5": + return HDF5Reader(rspecifier, return_shape=return_shape) + elif filetype == "sound.hdf5": + return SoundHDF5Reader(rspecifier, return_shape=return_shape) + elif filetype == "sound": + return SoundReader(rspecifier, return_shape=return_shape) + else: + raise NotImplementedError(f"filetype={filetype}") + + +class KaldiReader: + def __init__(self, rspecifier, return_shape=False, segments=None): + self.rspecifier = rspecifier + self.return_shape = return_shape + self.segments = segments + + def __iter__(self): + with kaldiio.ReadHelper( + self.rspecifier, segments=self.segments) as reader: + for key, array in reader: + if self.return_shape: + array = array.shape + yield key, array + + +class HDF5Reader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'. + format(self.rspecifier)) + self.rspecifier = rspecifier + self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1) + if self.ark_or_scp not in ["ark", "scp"]: + raise ValueError(f"Must be scp or ark: {self.ark_or_scp}") + + self.return_shape = return_shape + + def __iter__(self): + if self.ark_or_scp == "scp": + hdf5_dict = {} + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, value = line.rstrip().split(None, 1) + + if ":" not in value: + raise RuntimeError( + "scp file for hdf5 should be like: " + '"uttid filepath.h5:key": {}({})'.format( + line, self.filepath)) + path, h5_key = value.split(":", 1) + + hdf5_file = hdf5_dict.get(path) + if hdf5_file is None: + try: + hdf5_file = h5py.File(path, "r") + except Exception: + logging.error("Error when loading {}".format(path)) + raise + hdf5_dict[path] = hdf5_file + + try: + data = hdf5_file[h5_key] + except Exception: + logging.error("Error when loading {} with key={}". + format(path, h5_key)) + raise + + if self.return_shape: + yield key, data.shape + else: + yield key, data[()] + + # Closing all files + for k in hdf5_dict: + try: + hdf5_dict[k].close() + except Exception: + pass + + else: + if self.filepath == "-": + # Required h5py>=2.9 + filepath = io.BytesIO(sys.stdin.buffer.read()) + else: + filepath = self.filepath + with h5py.File(filepath, "r") as f: + for key in f: + if self.return_shape: + yield key, f[key].shape + else: + yield key, f[key][()] + + +class SoundHDF5Reader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError('Give "rspecifier" such as "ark:some.ark: {}"'. 
+ format(rspecifier)) + self.ark_or_scp, self.filepath = rspecifier.split(":", 1) + if self.ark_or_scp not in ["ark", "scp"]: + raise ValueError(f"Must be scp or ark: {self.ark_or_scp}") + self.return_shape = return_shape + + def __iter__(self): + if self.ark_or_scp == "scp": + hdf5_dict = {} + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, value = line.rstrip().split(None, 1) + + if ":" not in value: + raise RuntimeError( + "scp file for hdf5 should be like: " + '"uttid filepath.h5:key": {}({})'.format( + line, self.filepath)) + path, h5_key = value.split(":", 1) + + hdf5_file = hdf5_dict.get(path) + if hdf5_file is None: + try: + hdf5_file = SoundHDF5File(path, "r") + except Exception: + logging.error("Error when loading {}".format(path)) + raise + hdf5_dict[path] = hdf5_file + + try: + data = hdf5_file[h5_key] + except Exception: + logging.error("Error when loading {} with key={}". + format(path, h5_key)) + raise + + # Change Tuple[ndarray, int] -> Tuple[int, ndarray] + # (soundfile style -> scipy style) + array, rate = data + if self.return_shape: + array = array.shape + yield key, (rate, array) + + # Closing all files + for k in hdf5_dict: + try: + hdf5_dict[k].close() + except Exception: + pass + + else: + if self.filepath == "-": + # Required h5py>=2.9 + filepath = io.BytesIO(sys.stdin.buffer.read()) + else: + filepath = self.filepath + for key, (a, r) in SoundHDF5File(filepath, "r").items(): + if self.return_shape: + a = a.shape + yield key, (r, a) + + +class SoundReader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError('Give "rspecifier" such as "scp:some.scp: {}"'. + format(rspecifier)) + self.ark_or_scp, self.filepath = rspecifier.split(":", 1) + if self.ark_or_scp != "scp": + raise ValueError('Only supporting "scp" for sound file: {}'.format( + self.ark_or_scp)) + self.return_shape = return_shape + + def __iter__(self): + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, sound_file_path = line.rstrip().split(None, 1) + # Assume PCM16 + array, rate = soundfile.read(sound_file_path, dtype="int16") + # Change Tuple[ndarray, int] -> Tuple[int, ndarray] + # (soundfile style -> scipy style) + if self.return_shape: + array = array.shape + yield key, (rate, array) diff --git a/deepspeech/utils/cli_utils.py b/deepspeech/utils/cli_utils.py new file mode 100644 index 00000000..f8e1d60b --- /dev/null +++ b/deepspeech/utils/cli_utils.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +from collections.abc import Sequence +from distutils.util import strtobool as dist_strtobool + +import numpy + + +def strtobool(x): + # distutils.util.strtobool returns integer, but it's confusing, + return bool(dist_strtobool(x)) + + +def get_commandline_args(): + extra_chars = [ + " ", + ";", + "&", + "(", + ")", + "|", + "^", + "<", + ">", + "?", + "*", + "[", + "]", + "$", + "`", + '"', + "\\", + "!", + "{", + "}", + ] + + # Escape the extra characters for shell + argv = [ + arg.replace("'", "'\\''") if all(char not in arg + for char in extra_chars) else + "'" + arg.replace("'", "'\\''") + "'" for arg in sys.argv + ] + + return sys.executable + " " + " ".join(argv) + + +def is_scipy_wav_style(value): + # If Tuple[int, numpy.ndarray] or not + return (isinstance(value, Sequence) and len(value) == 2 and + isinstance(value[0], int) and isinstance(value[1], numpy.ndarray)) + + +def assert_scipy_wav_style(value): + assert is_scipy_wav_style( + value), "Must be Tuple[int, numpy.ndarray], but got {}".format( + type(value) if not isinstance(value, Sequence) else "{}[{}]".format( + type(value), ", ".join(str(type(v)) for v in value))) diff --git a/deepspeech/utils/cli_writers.py b/deepspeech/utils/cli_writers.py new file mode 100644 index 00000000..e0737193 --- /dev/null +++ b/deepspeech/utils/cli_writers.py @@ -0,0 +1,293 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Dict + +import h5py +import kaldiio +import numpy +import soundfile + +from deepspeech.io.reader import SoundHDF5File +from deepspeech.utils.cli_utils import assert_scipy_wav_style + + +def file_writer_helper( + wspecifier: str, + filetype: str="mat", + write_num_frames: str=None, + compress: bool=False, + compression_method: int=2, + pcm_format: str="wav", ): + """Write matrices in kaldi style + + Args: + wspecifier: e.g. ark,scp:out.ark,out.scp + filetype: "mat" is kaldi-martix, "hdf5": HDF5 + write_num_frames: e.g. 'ark,t:num_frames.txt' + compress: Compress or not + compression_method: Specify compression level + + Write in kaldi-matrix-ark with "kaldi-scp" file: + + >>> with file_writer_helper('ark,scp:out.ark,out.scp') as f: + >>> f['uttid'] = array + + This "scp" has the following format: + + uttidA out.ark:1234 + uttidB out.ark:2222 + + where, 1234 and 2222 points the strating byte address of the matrix. + (For detail, see official documentation of Kaldi) + + Write in HDF5 with "scp" file: + + >>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f: + >>> f['uttid'] = array + + This "scp" file is created as: + + uttidA out.h5:uttidA + uttidB out.h5:uttidB + + HDF5 can be, unlike "kaldi-ark", accessed to any keys, + so originally "scp" is not required for random-reading. + Nevertheless we create "scp" for HDF5 because it is useful + for some use-case. e.g. Concatenation, Splitting. 
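+
+    Write raw wave files into a directory together with an "scp" file
+    (a sketch; the directory name is illustrative, one <uttid>.wav file is
+    created per utterance, and the array is written as int16 PCM):
+
+    >>> fs = 16000
+    >>> with file_writer_helper('ark,scp:outdir,wav.scp', 'sound') as f:
+    >>> f['uttid'] = fs, array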
+ + """ + if filetype == "mat": + return KaldiWriter( + wspecifier, + write_num_frames=write_num_frames, + compress=compress, + compression_method=compression_method, ) + elif filetype == "hdf5": + return HDF5Writer( + wspecifier, write_num_frames=write_num_frames, compress=compress) + elif filetype == "sound.hdf5": + return SoundHDF5Writer( + wspecifier, + write_num_frames=write_num_frames, + pcm_format=pcm_format) + elif filetype == "sound": + return SoundWriter( + wspecifier, + write_num_frames=write_num_frames, + pcm_format=pcm_format) + else: + raise NotImplementedError(f"filetype={filetype}") + + +class BaseWriter: + def __setitem__(self, key, value): + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + try: + self.writer.close() + except Exception: + pass + + if self.writer_scp is not None: + try: + self.writer_scp.close() + except Exception: + pass + + if self.writer_nframe is not None: + try: + self.writer_nframe.close() + except Exception: + pass + + +def get_num_frames_writer(write_num_frames: str): + """get_num_frames_writer + + Examples: + >>> get_num_frames_writer('ark,t:num_frames.txt') + """ + if write_num_frames is not None: + if ":" not in write_num_frames: + raise ValueError('Must include ":", write_num_frames={}'.format( + write_num_frames)) + + nframes_type, nframes_file = write_num_frames.split(":", 1) + if nframes_type != "ark,t": + raise ValueError("Only supporting text mode. " + "e.g. --write-num-frames=ark,t:foo.txt :" + "{}".format(nframes_type)) + + return open(nframes_file, "w", encoding="utf-8") + + +class KaldiWriter(BaseWriter): + def __init__(self, + wspecifier, + write_num_frames=None, + compress=False, + compression_method=2): + if compress: + self.writer = kaldiio.WriteHelper( + wspecifier, compression_method=compression_method) + else: + self.writer = kaldiio.WriteHelper(wspecifier) + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + self.writer[key] = value + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value)}\n") + + +def parse_wspecifier(wspecifier: str) -> Dict[str, str]: + """Parse wspecifier to dict + + Examples: + >>> parse_wspecifier('ark,scp:out.ark,out.scp') + {'ark': 'out.ark', 'scp': 'out.scp'} + + """ + ark_scp, filepath = wspecifier.split(":", 1) + if ark_scp not in ["ark", "scp,ark", "ark,scp"]: + raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier)) + ark_scps = ark_scp.split(",") + filepaths = filepath.split(",") + if len(ark_scps) != len(filepaths): + raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath)) + spec_dict = dict(zip(ark_scps, filepaths)) + return spec_dict + + +class HDF5Writer(BaseWriter): + """HDF5Writer + + Examples: + >>> with HDF5Writer('ark:out.h5', compress=True) as f: + ... 
f['key'] = array + """ + + def __init__(self, wspecifier, write_num_frames=None, compress=False): + spec_dict = parse_wspecifier(wspecifier) + self.filename = spec_dict["ark"] + + if compress: + self.kwargs = {"compression": "gzip"} + else: + self.kwargs = {} + self.writer = h5py.File(spec_dict["ark"], "w") + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + self.writer.create_dataset(key, data=value, **self.kwargs) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {self.filename}:{key}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value)}\n") + + +class SoundHDF5Writer(BaseWriter): + """SoundHDF5Writer + + Examples: + >>> fs = 16000 + >>> with SoundHDF5Writer('ark:out.h5') as f: + ... f['key'] = fs, array + """ + + def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"): + self.pcm_format = pcm_format + spec_dict = parse_wspecifier(wspecifier) + self.filename = spec_dict["ark"] + self.writer = SoundHDF5File( + spec_dict["ark"], "w", format=self.pcm_format) + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + assert_scipy_wav_style(value) + # Change Tuple[int, ndarray] -> Tuple[ndarray, int] + # (scipy style -> soundfile style) + value = (value[1], value[0]) + self.writer.create_dataset(key, data=value) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {self.filename}:{key}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value[0])}\n") + + +class SoundWriter(BaseWriter): + """SoundWriter + + Examples: + >>> fs = 16000 + >>> with SoundWriter('ark,scp:outdir,out.scp') as f: + ... f['key'] = fs, array + """ + + def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"): + self.pcm_format = pcm_format + spec_dict = parse_wspecifier(wspecifier) + # e.g. ark,scp:dirname,wav.scp + # -> The wave files are found in dirname/*.wav + self.dirname = spec_dict["ark"] + Path(self.dirname).mkdir(parents=True, exist_ok=True) + self.writer = None + + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + assert_scipy_wav_style(value) + rate, signal = value + wavfile = Path(self.dirname) / (key + "." + self.pcm_format) + soundfile.write(wavfile, signal.astype(numpy.int16), rate) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {wavfile}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(signal)}\n") diff --git a/deepspeech/utils/error_rate.py b/deepspeech/utils/error_rate.py index 81f458b6..548376aa 100644 --- a/deepspeech/utils/error_rate.py +++ b/deepspeech/utils/error_rate.py @@ -14,12 +14,12 @@ """This module provides functions to calculate error rate in different level. e.g. wer for word-level, cer for char-level. 
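
A small usage sketch (strings are illustrative; wer is assumed to take
(reference, hypothesis) in the same order as cer below):

    >>> cer("i like monkeys", "i like monkey")
    >>> wer("i like monkeys", "i like monkey")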
""" +from itertools import groupby + import editdistance import numpy as np -__all__ = ['word_errors', 'char_errors', 'wer', 'cer'] - -editdistance.eval("a", "b") +__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"] def _levenshtein_distance(ref, hyp): @@ -211,3 +211,154 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): cer = float(edit_distance) / ref_len return cer + + +class ErrorCalculator(): + """Calculate CER and WER for E2E_ASR and CTC models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: List[str] + :param sym_space: + :param sym_blank: + :return: + """ + + def __init__(self, + char_list, + sym_space, + sym_blank, + report_cer=False, + report_wer=False): + """Construct an ErrorCalculator object.""" + super().__init__() + + self.report_cer = report_cer + self.report_wer = report_wer + + self.char_list = char_list + self.space = sym_space + self.blank = sym_blank + self.idx_blank = self.char_list.index(self.blank) + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad, is_ctc=False): + """Calculate sentence-level WER/CER score. + + :param paddle.Tensor ys_hat: prediction (batch, seqlen) + :param paddle.Tensor ys_pad: reference (batch, seqlen) + :param bool is_ctc: calculate CER score for CTC + :return: sentence-level WER score + :rtype float + :return: sentence-level CER score + :rtype float + """ + cer, wer = None, None + if is_ctc: + return self.calculate_cer_ctc(ys_hat, ys_pad) + elif not self.report_cer and not self.report_wer: + return cer, wer + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + return cer, wer + + def calculate_cer_ctc(self, ys_hat, ys_pad): + """Calculate sentence-level CER score for CTC. + + :param paddle.Tensor ys_hat: prediction (batch, seqlen) + :param paddle.Tensor ys_pad: reference (batch, seqlen) + :return: average sentence-level CER score + :rtype float + """ + cers, char_ref_lens = [], [] + for i, y in enumerate(ys_hat): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + seq_hat, seq_true = [], [] + for idx in y_hat: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_hat.append(self.char_list[int(idx)]) + + for idx in y_true: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_true.append(self.char_list[int(idx)]) + + hyp_chars = "".join(seq_hat) + ref_chars = "".join(seq_true) + if len(ref_chars) > 0: + cers.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None + return cer_ctc + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. 
+ + :param paddle.Tensor seqs_hat: prediction (batch, seqlen) + :param paddle.Tensor seqs_true: reference (batch, seqlen) + :return: token list of prediction + :rtype list + :return: token list of reference + :rtype list + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) + # NOTE: padding index (-1) in y_true is used to pad y_hat + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level CER score + :rtype float + """ + char_eds, char_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level WER score + :rtype float + """ + word_eds, word_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + return float(sum(word_eds)) / sum(word_ref_lens) diff --git a/deepspeech/utils/spec_augment.py b/deepspeech/utils/spec_augment.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/librispeech/s2/.gitignore b/examples/librispeech/s2/.gitignore new file mode 100644 index 00000000..e56b7d34 --- /dev/null +++ b/examples/librispeech/s2/.gitignore @@ -0,0 +1,4 @@ +dump +fbank +exp +data diff --git a/examples/librispeech/s2/README.md b/examples/librispeech/s2/README.md index d5df37d8..fc634ff6 100644 --- a/examples/librispeech/s2/README.md +++ b/examples/librispeech/s2/README.md @@ -1,8 +1,11 @@ # LibriSpeech -| Model | Params | Config | Augmentation| Loss | -| --- | --- | --- | --- | -| transformer | 32.52 M | conf/transformer.yaml | spec_aug | 6.3197922706604 | + +## Transformer + +| Model | Params | GPUS | Averaged Model | Config | Augmentation| Loss | +| --- | --- | --- | --- | --- | --- | +| transformer | 32.52 M | 8 Tesla V100-SXM2-32GB | 10-best val_loss | conf/transformer.yaml | spec_aug | 6.3197922706604 | | Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | @@ -11,4 +14,14 @@ | test-clean | ctc_greedy_search | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 48.0 | | test-clean | ctc_prefix_beamsearch | 2620 | 52576 | 95.9 | 3.7 | 0.4 | 0.5 | 4.6 | 47.6 | | test-clean | attention_rescore | 2620 | 52576 | 96.8 | 2.9 | 0.3 | 0.4 | 3.7 | 38.0 | + +### JoinCTC + +| Test Set | Decode Method | #Snt | #Wrd | Corr | Sub | Del | Ins | Err | S.Err | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| 
test-clean | join_ctc_only_att | 2620 | 52576 | 96.1 | 2.5 | 1.4 | 0.4 | 4.4 | 34.7 | | test-clean | join_ctc_w/o_lm | 2620 | 52576 | 97.2 | 2.6 | 0.3 | 0.4 | 3.2 | 34.9 | +| test-clean | join_ctc_w_lm | 2620 | 52576 | 97.9 | 1.8 | 0.2 | 0.3 | 2.4 | 27.8 | + +Compare with [ESPNET](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-transformer-with-specaug-4-gpus--transformer-lm-4-gpus) +we using 8gpu, but model size (aheads4-adim256) small than it. diff --git a/examples/librispeech/s2/conf/chunk_conformer.yaml b/examples/librispeech/s2/conf/chunk_conformer.yaml deleted file mode 100644 index afd2b051..00000000 --- a/examples/librispeech/s2/conf/chunk_conformer.yaml +++ /dev/null @@ -1,122 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 - max_input_len: 20.0 - min_output_len: 0.0 - max_output_len: 400.0 - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: conformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: True - use_cnn_module: True - cnn_module_kernel: 15 - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' - causal: True - use_dynamic_chunk: true - cnn_module_norm: 'layer_norm' # using nn.LayerNorm makes model converge faster - use_dynamic_left_chunk: false - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 240 - accum_grad: 8 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.001 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 128 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc 
weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: true # simulate streaming inference. Defaults to False. - - diff --git a/examples/librispeech/s2/conf/chunk_transformer.yaml b/examples/librispeech/s2/conf/chunk_transformer.yaml deleted file mode 100644 index 721bb7d9..00000000 --- a/examples/librispeech/s2/conf/chunk_transformer.yaml +++ /dev/null @@ -1,115 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test - min_input_len: 0.5 # second - max_input_len: 20.0 # second - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: transformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: true - use_dynamic_chunk: true - use_dynamic_left_chunk: false - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 120 - accum_grad: 1 - global_grad_clip: 5.0 - optim: adam - optim_conf: - lr: 0.001 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 64 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. 
- num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: true # simulate streaming inference. Defaults to False. - - diff --git a/examples/librispeech/s2/conf/conformer.yaml b/examples/librispeech/s2/conf/conformer.yaml deleted file mode 100644 index ef87753c..00000000 --- a/examples/librispeech/s2/conf/conformer.yaml +++ /dev/null @@ -1,118 +0,0 @@ -# https://yaml.org/type/float.html -data: - train_manifest: data/manifest.train - dev_manifest: data/manifest.dev - test_manifest: data/manifest.test-clean - min_input_len: 0.5 # seconds - max_input_len: 20.0 # seconds - min_output_len: 0.0 # tokens - max_output_len: 400.0 # tokens - min_output_input_ratio: 0.05 - max_output_input_ratio: 10.0 - -collator: - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 - raw_wav: True # use raw_wav or kaldi feature - spectrum_type: fbank #linear, mfcc, fbank - feat_dim: 80 - delta_delta: False - dither: 1.0 - target_sample_rate: 16000 - max_freq: None - n_fft: None - stride_ms: 10.0 - window_ms: 25.0 - use_dB_normalization: True - target_dB: -20 - random_seed: 0 - keep_transcription_text: False - sortagrad: True - shuffle_method: batch_shuffle - num_workers: 2 - - -# network architecture -model: - cmvn_file: "data/mean_std.json" - cmvn_file_type: "json" - # encoder related - encoder: conformer - encoder_conf: - output_size: 256 # dimension of attention - attention_heads: 4 - linear_units: 2048 # the number of units of position-wise feed forward - num_blocks: 12 # the number of encoder blocks - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - attention_dropout_rate: 0.0 - input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 - normalize_before: True - use_cnn_module: True - cnn_module_kernel: 15 - activation_type: 'swish' - pos_enc_layer_type: 'rel_pos' - selfattention_layer_type: 'rel_selfattn' - - # decoder related - decoder: transformer - decoder_conf: - attention_heads: 4 - linear_units: 2048 - num_blocks: 6 - dropout_rate: 0.1 - positional_dropout_rate: 0.1 - self_attention_dropout_rate: 0.0 - src_attention_dropout_rate: 0.0 - - # hybrid CTC/attention - model_conf: - ctc_weight: 0.3 - ctc_dropoutrate: 0.0 - ctc_grad_norm_type: null - lsm_weight: 0.1 # label smoothing option - length_normalized_loss: false - - -training: - n_epoch: 120 - accum_grad: 8 - global_grad_clip: 3.0 - optim: adam - optim_conf: - lr: 0.004 - weight_decay: 1e-06 - scheduler: warmuplr # pytorch v1.1.0+ required - scheduler_conf: - warmup_steps: 25000 - lr_decay: 1.0 - log_interval: 100 - checkpoint: - kbest_n: 50 - latest_n: 5 - - -decoding: - batch_size: 64 - error_rate_type: wer - decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' - lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 - beta: 0.3 - beam_size: 10 - cutoff_prob: 1.0 - cutoff_top_n: 0 - num_proc_bsearch: 8 - ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. - decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. - # <0: for decoding, use full chunk. - # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. - num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. 
- - diff --git a/examples/librispeech/s2/conf/fbank.conf b/examples/librispeech/s2/conf/fbank.conf new file mode 100644 index 00000000..82ac7bd0 --- /dev/null +++ b/examples/librispeech/s2/conf/fbank.conf @@ -0,0 +1,2 @@ +--sample-frequency=16000 +--num-mel-bins=80 diff --git a/examples/librispeech/s2/conf/lm/transformer.yaml b/examples/librispeech/s2/conf/lm/transformer.yaml new file mode 100644 index 00000000..4349f795 --- /dev/null +++ b/examples/librispeech/s2/conf/lm/transformer.yaml @@ -0,0 +1,13 @@ +model_module: transformer +model: + n_vocab: 5002 + pos_enc: null + embed_unit: 128 + att_unit: 512 + head: 8 + unit: 2048 + layer: 16 + dropout_rate: 0.5 + emb_dropout_rate: 0.0 + att_dropout_rate: 0.0 + tie_weights: False diff --git a/examples/librispeech/s2/conf/pitch.conf b/examples/librispeech/s2/conf/pitch.conf new file mode 100644 index 00000000..e959a19d --- /dev/null +++ b/examples/librispeech/s2/conf/pitch.conf @@ -0,0 +1 @@ +--sample-frequency=16000 diff --git a/examples/librispeech/s2/conf/transformer.yaml b/examples/librispeech/s2/conf/transformer.yaml index c9eed4f9..b2babca7 100644 --- a/examples/librispeech/s2/conf/transformer.yaml +++ b/examples/librispeech/s2/conf/transformer.yaml @@ -5,9 +5,9 @@ data: test_manifest: data/manifest.test-clean collator: - vocab_filepath: data/bpe_unigram_5000_units.txt + vocab_filepath: data/lang_char/train_960_unigram5000_units.txt unit_type: spm - spm_model_prefix: data/bpe_unigram_5000 + spm_model_prefix: data/lang_char/train_960_unigram5000 feat_dim: 83 stride_ms: 10.0 window_ms: 25.0 diff --git a/examples/librispeech/s2/local/data.sh b/examples/librispeech/s2/local/data.sh index 56fec846..b232f35a 100755 --- a/examples/librispeech/s2/local/data.sh +++ b/examples/librispeech/s2/local/data.sh @@ -2,19 +2,42 @@ stage=-1 stop_stage=100 +nj=32 +debugmode=1 +dumpdir=dump # directory to dump full features +N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches. +verbose=0 # verbose option +resume= # Resume the training from snapshot + +# feature configuration +do_delta=false + +# Set this to somewhere where you want to put your data, or where +# someone else has already put it. You'll want to change this +# if you're not on the CLSP grid. +datadir=${MAIN_ROOT}/examples/dataset/ # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" source ${MAIN_ROOT}/utils/parse_options.sh +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set=train_960 +train_sp=train_sp +train_dev=dev +recog_set="test_clean test_other dev_clean dev_other" + mkdir -p data TARGET_DIR=${MAIN_ROOT}/examples/dataset mkdir -p ${TARGET_DIR} - if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then # download data, generate manifests python3 ${TARGET_DIR}/librispeech/librispeech.py \ @@ -46,63 +69,98 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then fi if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then - # compute mean and stddev for normalizer - num_workers=$(nproc) - python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ - --manifest_path="data/manifest.train.raw" \ - --num_samples=-1 \ - --spectrum_type="fbank" \ - --feat_dim=80 \ - --delta_delta=false \ - --sample_rate=16000 \ - --stride_ms=10.0 \ - --window_ms=25.0 \ - --use_dB_normalization=False \ - --num_workers=${num_workers} \ - --output_path="data/mean_std.json" - - if [ $? 
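The new `conf/lm/transformer.yaml` above carries the hyperparameters of the transformer LM used for rescoring during joint decoding. A small, hedged sketch of reading it with PyYAML to sanity-check the settings before decoding; only the config path comes from this patch, the rest is illustrative.

```python
import yaml

# Path as added by this patch (relative to examples/librispeech/s2).
with open("conf/lm/transformer.yaml", encoding="utf-8") as f:
    lm_conf = yaml.safe_load(f)

assert lm_conf["model_module"] == "transformer"
print(lm_conf["model"]["n_vocab"])  # -> 5002
print(lm_conf["model"]["layer"])    # -> 16
```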
-ne 0 ]; then - echo "Compute mean and stddev failed. Terminated." - exit 1 - fi + ### Task dependent. You have to make data the following preparation part by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "stage 0: Data preparation" + for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do + # use underscore-separated names in data directories. + local/data_prep.sh ${datadir}/librispeech/${part}/LibriSpeech/${part} data/${part//-/_} + done fi +feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir} +feat_sp_dir=${dumpdir}/${train_sp}/delta${do_delta}; mkdir -p ${feat_sp_dir} +feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir} if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then - # build vocabulary - python3 ${MAIN_ROOT}/utils/build_vocab.py \ - --unit_type "spm" \ - --spm_vocab_size=${nbpe} \ - --spm_mode ${bpemode} \ - --spm_model_prefix ${bpeprefix} \ - --vocab_path="data/vocab.txt" \ - --manifest_paths="data/manifest.train.raw" + ### Task dependent. You have to design training and dev sets by yourself. + ### But you can utilize Kaldi recipes in most cases + echo "stage 1: Feature Generation" + fbankdir=fbank + # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame + for x in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do + steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj ${nj} --write_utt2num_frames true \ + data/${x} exp/make_fbank/${x} ${fbankdir} + utils/fix_data_dir.sh data/${x} + done - if [ $? -ne 0 ]; then - echo "Build vocabulary failed. Terminated." - exit 1 - fi + utils/combine_data.sh --extra_files utt2num_frames data/${train_set}_org data/train_clean_100 data/train_clean_360 data/train_other_500 + utils/combine_data.sh --extra_files utt2num_frames data/${train_dev}_org data/dev_clean data/dev_other + utils/perturb_data_dir_speed.sh 0.9 data/${train_set}_org data/temp1 + utils/perturb_data_dir_speed.sh 1.0 data/${train_set}_org data/temp2 + utils/perturb_data_dir_speed.sh 1.1 data/${train_set}_org data/temp3 + + utils/combine_data.sh --extra-files utt2uniq data/${train_sp}_org data/temp1 data/temp2 data/temp3 + + # remove utt having more than 3000 frames + # remove utt having more than 400 characters + remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_set}_org data/${train_set} + remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_sp}_org data/${train_sp} + remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_dev}_org data/${train_dev} + steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj $nj --write_utt2num_frames true \ + data/train_sp exp/make_fbank/train_sp ${fbankdir} + utils/fix_data_dir.sh data/train_sp + # compute global CMVN + compute-cmvn-stats scp:data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark + + # dump features for training + dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \ + data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/train ${feat_sp_dir} + dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \ + data/${train_dev}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/dev ${feat_dt_dir} + for rtask in ${recog_set}; do + feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir} + dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \ + data/${rtask}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/recog/${rtask} \ + 
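Stage 1 above drops utterances longer than 3000 frames or 400 characters via `remove_longshortdata.sh` before dumping features. The following is only an illustration of that filter in Python, assuming the usual Kaldi data-dir files (`utt2num_frames`, `text`); it is not the script the recipe actually calls.

```python
def load_kv(path):
    """Read a Kaldi-style '<utt> <value...>' file into a dict."""
    table = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(maxsplit=1)
            if parts:
                table[parts[0]] = parts[1] if len(parts) > 1 else ""
    return table


def keep_utts(data_dir, max_frames=3000, max_chars=400):
    """Return utterance ids that survive the long/short filtering."""
    frames = load_kv(f"{data_dir}/utt2num_frames")
    text = load_kv(f"{data_dir}/text")
    return [
        utt for utt in text
        if utt in frames
        and int(frames[utt]) <= max_frames
        and len(text[utt]) <= max_chars
    ]


print(len(keep_utts("data/train_sp_org")))
```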
${feat_recog_dir} + done fi +dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt +bpemodel=data/lang_char/${train_set}_${bpemode}${nbpe} +echo "dictionary: ${dict}" if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then - # format manifest with tokenids, vocab size - for set in train dev test dev-clean dev-other test-clean test-other; do - { - python3 ${MAIN_ROOT}/utils/format_data.py \ - --feat_type "raw" \ - --cmvn_path "data/mean_std.json" \ - --unit_type "spm" \ - --spm_model_prefix ${bpeprefix} \ - --vocab_path="data/vocab.txt" \ - --manifest_path="data/manifest.${set}.raw" \ - --output_path="data/manifest.${set}" - - if [ $? -ne 0 ]; then - echo "Formt mnaifest failed. Terminated." - exit 1 - fi - }& + ### Task dependent. You have to check non-linguistic symbols used in the corpus. + echo "stage 2: Dictionary and Json Data Preparation" + mkdir -p data/lang_char/ + echo " 1" > ${dict} # must be 1, 0 will be used for "blank" in CTC + cut -f 2- -d" " data/${train_set}/text > data/lang_char/input.txt + spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 + spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict} + wc -l ${dict} + + # make json labels + data2json.sh --nj ${nj} --feat ${feat_sp_dir}/feats.scp --bpecode ${bpemodel}.model \ + data/${train_sp} ${dict} > ${feat_sp_dir}/data_${bpemode}${nbpe}.json + data2json.sh --nj ${nj} --feat ${feat_dt_dir}/feats.scp --bpecode ${bpemodel}.model \ + data/${train_dev} ${dict} > ${feat_dt_dir}/data_${bpemode}${nbpe}.json + + for rtask in ${recog_set}; do + feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} + data2json.sh --nj ${nj} --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model \ + data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json + done +fi + + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # make json labels + python3 local/espnet_json_to_manifest.py --json-file ${feat_sp_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.train + python3 local/espnet_json_to_manifest.py --json-file ${feat_dt_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.dev + + for rtask in ${recog_set}; do + feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta} + python3 local/espnet_json_to_manifest.py --json-file ${feat_recog_dir}/data_${bpemode}${nbpe}.json --manifest-file data/manifest.${rtask//_/-} done - wait fi echo "LibriSpeech Data preparation done." diff --git a/examples/librispeech/s2/local/data_prep.sh b/examples/librispeech/s2/local/data_prep.sh new file mode 100755 index 00000000..c903d45b --- /dev/null +++ b/examples/librispeech/s2/local/data_prep.sh @@ -0,0 +1,85 @@ +#!/usr/bin/env bash + +# Copyright 2014 Vassil Panayotov +# 2014 Johns Hopkins University (author: Daniel Povey) +# Apache 2.0 + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean" + exit 1 +fi + +src=$1 +dst=$2 + +# all utterances are FLAC compressed +if ! which flac >&/dev/null; then + echo "Please install 'flac' on ALL worker nodes!" + exit 1 +fi + +spk_file=$src/../SPEAKERS.TXT + +mkdir -p $dst || exit 1 + +[ ! -d $src ] && echo "$0: no such directory $src" && exit 1 +[ ! 
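Stage 2 above builds the BPE unit dictionary with `spm_train`/`spm_encode`, reserving id 0 for the CTC blank and id 1 for the unknown symbol, and numbering the remaining units from 2. A hedged sketch of the same steps through the `sentencepiece` Python API; paths follow the variables set in this script, and the `<unk>` symbol name follows the usual ESPnet convention.

```python
import sentencepiece as spm

nbpe = 5000
prefix = "data/lang_char/train_960_unigram5000"

# Train the unigram model on the plain-text transcripts (stage 2 writes input.txt).
spm.SentencePieceTrainer.train(
    input="data/lang_char/input.txt",
    vocab_size=nbpe,
    model_type="unigram",
    model_prefix=prefix,
    input_sentence_size=100000000,
)

# Collect the unique pieces seen on the training text and assign ids from 2.
sp = spm.SentencePieceProcessor(model_file=f"{prefix}.model")
units = set()
with open("data/lang_char/input.txt", encoding="utf-8") as fin:
    for line in fin:
        units.update(sp.encode(line.strip(), out_type=str))

with open(f"{prefix}_units.txt", "w", encoding="utf-8") as fout:
    fout.write("<unk> 1\n")  # id 0 is reserved for the CTC blank
    for i, unit in enumerate(sorted(units), start=2):
        fout.write(f"{unit} {i}\n")
```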
-f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1 + + +wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp +trans=$dst/text; [[ -f "$trans" ]] && rm $trans +utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk +spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender + +for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do + reader=$(basename $reader_dir) + if ! [ $reader -eq $reader ]; then # not integer. + echo "$0: unexpected subdirectory name $reader" + exit 1 + fi + + reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}') + if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then + echo "Unexpected gender: '$reader_gender'" + exit 1 + fi + + for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do + chapter=$(basename $chapter_dir) + if ! [ "$chapter" -eq "$chapter" ]; then + echo "$0: unexpected chapter-subdirectory name $chapter" + exit 1 + fi + + find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \ + awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1 + + chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt + [ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1 + cat $chapter_trans >>$trans + + # NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered + # to be a different speaker. This is done for simplicity and because we want + # e.g. the CMVN to be calculated per-chapter + awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \ + <$chapter_trans >>$utt2spk || exit 1 + + # reader -> gender map (again using per-chapter granularity) + echo "${reader}-${chapter} $reader_gender" >>$spk2gender + done +done + +spk2utt=$dst/spk2utt +utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1 + +ntrans=$(wc -l <$trans) +nutt2spk=$(wc -l <$utt2spk) +! 
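`data_prep.sh` above writes `utt2spk` with per-chapter "speakers" and then calls `utils/utt2spk_to_spk2utt.pl` to invert it. A small sketch of that inversion, purely to illustrate the two file formats; the path is a placeholder.

```python
from collections import defaultdict


def utt2spk_to_spk2utt(utt2spk_path, spk2utt_path):
    """Invert an utterance->speaker map into 'spk utt1 utt2 ...' lines."""
    spk2utt = defaultdict(list)
    with open(utt2spk_path, encoding="utf-8") as fin:
        for line in fin:
            utt, spk = line.split()
            spk2utt[spk].append(utt)
    with open(spk2utt_path, "w", encoding="utf-8") as fout:
        for spk in sorted(spk2utt):
            fout.write(f"{spk} {' '.join(sorted(spk2utt[spk]))}\n")


utt2spk_to_spk2utt("data/dev_clean/utt2spk", "data/dev_clean/spk2utt")
```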
[ "$ntrans" -eq "$nutt2spk" ] && \ + echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1 + +utils/validate_data_dir.sh --no-feats $dst || exit 1 + +echo "$0: successfully prepared data in $dst" + +exit 0 diff --git a/examples/librispeech/s2/local/recog.sh b/examples/librispeech/s2/local/recog.sh index df3846c0..f0e96109 100755 --- a/examples/librispeech/s2/local/recog.sh +++ b/examples/librispeech/s2/local/recog.sh @@ -11,22 +11,24 @@ tag= decode_config=conf/decode/decode.yaml # lm params -lang_model=rnnlm.model.best -lmexpdir=exp/train_rnnlm_pytorch_lm_transformer_cosine_batchsize32_lr1e-4_layer16_unigram5000_ngpu4/ -lmtag='nolm' +lang_model=transformerLM.pdparams +lmexpdir=exp/lm/transformer +rnnlm_config_path=conf/lm/transformer.yaml +lmtag='transformer' +train_set=train_960 recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean" # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe} bpemodel=${bpeprefix}.model # bin params config_path=conf/transformer.yaml -dict=data/bpe_unigram_5000_units.txt +dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt ckpt_prefix= source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; @@ -90,9 +92,9 @@ for dmethd in join_ctc; do --recog-json ${feat_recog_dir}/split${nj}/JOB/manifest.${rtask} \ --result-label ${decode_dir}/data.JOB.json \ --model-conf ${config_path} \ - --model ${ckpt_prefix}.pdparams - - #--rnnlm ${lmexpdir}/${lang_model} \ + --model ${ckpt_prefix}.pdparams \ + --rnnlm-conf ${rnnlm_config_path} \ + --rnnlm ${lmexpdir}/${lang_model} score_sclite.sh --bpe ${nbpe} --bpemodel ${bpemodel} --wer false ${decode_dir} ${dict} diff --git a/examples/librispeech/s2/local/test.sh b/examples/librispeech/s2/local/test.sh index 5f662d29..23670f74 100755 --- a/examples/librispeech/s2/local/test.sh +++ b/examples/librispeech/s2/local/test.sh @@ -8,17 +8,18 @@ nj=32 lmtag='nolm' +train_set=train_960 recog_set="test-clean test-other dev-clean dev-other" recog_set="test-clean" # bpemode (unigram or bpe) nbpe=5000 bpemode=unigram -bpeprefix="data/bpe_${bpemode}_${nbpe}" +bpeprefix=data/lang_char/${train_set}_${bpemode}${nbpe} bpemodel=${bpeprefix}.model config_path=conf/transformer.yaml -dict=data/bpe_unigram_5000_units.txt +dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt ckpt_prefix= source ${MAIN_ROOT}/utils/parse_options.sh || exit 1; diff --git a/examples/librispeech/s2/path.sh b/examples/librispeech/s2/path.sh index eec437b6..32ff28c1 100644 --- a/examples/librispeech/s2/path.sh +++ b/examples/librispeech/s2/path.sh @@ -1,6 +1,6 @@ export MAIN_ROOT=`realpath ${PWD}/../../../` -export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH} +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${MAIN_ROOT}/utils:${PWD}/utils:${PATH} export LC_ALL=C export PYTHONDONTWRITEBYTECODE=1 @@ -13,3 +13,16 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ MODEL=u2_kaldi export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin + +# srilm +export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10 +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs +export SRILM=${MAIN_ROOT}/tools/srilm +export PATH=${PATH}:${SRILM}/bin:${SRILM}/bin/i686-m64 + +# Kaldi +export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi +[ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh +export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH +[ ! 
-f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!" +[ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh \ No newline at end of file diff --git a/examples/librispeech/s2/run.sh b/examples/librispeech/s2/run.sh index 3c7569fb..61172d25 100755 --- a/examples/librispeech/s2/run.sh +++ b/examples/librispeech/s2/run.sh @@ -33,16 +33,24 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then - # test ckpt avg_n + # attetion resocre decoder ./local/test.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi -if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ] && ${use_lm} == true; then + # join ctc decoder, use transformerlm to score + if [ ! -f exp/lm/transformer/transformerLM.pdparams ]; then + wget https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams exp/lm/transformer/ + fi + bash local/recog.sh --ckpt_prefix exp/${ckpt}/checkpoints/${avg_ckpt} +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # ctc alignment of test data CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} ${dict_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi -if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/librispeech/s2/steps b/examples/librispeech/s2/steps new file mode 120000 index 00000000..995eeccb --- /dev/null +++ b/examples/librispeech/s2/steps @@ -0,0 +1 @@ +../../../tools/kaldi/egs/wsj/s5/steps/ \ No newline at end of file diff --git a/examples/librispeech/s2/utils b/examples/librispeech/s2/utils index 256f914a..f49247da 120000 --- a/examples/librispeech/s2/utils +++ b/examples/librispeech/s2/utils @@ -1 +1 @@ -../../../utils/ \ No newline at end of file +../../../tools/kaldi/egs/wsj/s5/utils \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a7310a02..4878dbe3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ praatio~=4.1 pre-commit pybind11 pypinyin +python-dateutil pyworld resampy==0.2.2 sacrebleu @@ -41,3 +42,4 @@ visualdl==2.2.0 webrtcvad yacs yq +nara_wpe \ No newline at end of file diff --git a/setup.py b/setup.py index be17e0a4..a2e4c031 100644 --- a/setup.py +++ b/setup.py @@ -65,13 +65,6 @@ def _remove(files: str): def _post_install(install_lib_dir): - # apt - check_call("apt-get update -y") - check_call("apt-get install -y " + 'vim tig tree sox pkg-config ' + - 'libsndfile1 libflac-dev libogg-dev ' + - 'libvorbis-dev libboost-dev swig python3-dev ') - print("apt install.") - # tools/make tool_dir = HERE / "tools" _remove(tool_dir.glob("*.done")) diff --git a/setup.sh b/setup.sh index d3dd8207..aefdab98 100644 --- a/setup.sh +++ b/setup.sh @@ -10,7 +10,7 @@ fi if [ -e /etc/lsb-release ];then ${SUDO} apt-get update -y - ${SUDO} apt-get install -y jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + ${SUDO} apt-get install -y bc jq vim tig tree sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev if [ $? 
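Stage 4 of `run.sh` above fetches the pretrained transformer LM checkpoint into `exp/lm/transformer/` before calling `local/recog.sh`. A hedged Python equivalent of that fetch, using only the URL and destination that appear in the patch; the recipe itself uses `wget`.

```python
import os
import urllib.request

url = "https://deepspeech.bj.bcebos.com/transformer_lm/transformerLM.pdparams"
dst = "exp/lm/transformer/transformerLM.pdparams"

if not os.path.exists(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
    urllib.request.urlretrieve(url, dst)
```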
!= 0 ]; then error_msg "Please using Ubuntu or install pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev by user." exit -1 diff --git a/tools/Makefile b/tools/Makefile index 77b41a48..87107a53 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -10,7 +10,7 @@ WGET ?= wget --no-check-certificate .PHONY: all clean -all: virtualenv.done kenlm.done sox.done soxbindings.done mfa.done sclite.done +all: virtualenv.done apt.done kenlm.done sox.done soxbindings.done mfa.done sclite.done virtualenv.done: test -d venv || virtualenv -p $(PYTHON) venv @@ -21,6 +21,13 @@ clean: find -iname "*.pyc" -delete rm -rf kenlm + +apt.done: + apt update -y + apt install -y bc flac jq vim tig tree pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + echo "check_certificate = off" >> ~/.wgetrc + touch apt.done + kenlm.done: # Ubuntu 16.04 透過 apt 會安裝 boost 1.58.0 # it seems that boost (1.54.0) requires higher version. After I switched to g++-5 it compiles normally. @@ -48,6 +55,13 @@ mfa.done: tar xvf montreal-forced-aligner_linux.tar.gz touch mfa.done +openblas.done: + bash extras/install_openblas.sh + touch openblas.done + +kaldi.done: openblas.done + bash extras/install_kaldi.sh + touch kaldi.done #== SCTK =============================================================================== # SCTK official repo does not have version tags. Here's the mapping: diff --git a/tools/extras/install_kaldi.sh b/tools/extras/install_kaldi.sh index b87232b0..3cdcd32d 100755 --- a/tools/extras/install_kaldi.sh +++ b/tools/extras/install_kaldi.sh @@ -16,7 +16,7 @@ else echo "$KALDI_DIR already exists!" fi -cd "$KALDI_DIR/tools" +pushd "$KALDI_DIR/tools" git pull # Prevent kaldi from switching default python version @@ -28,8 +28,12 @@ touch "python/.use_default_python" make -j4 pushd ../src -./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${KALDI_DIR}/../OpenBLAS/install +OPENBLAS_DIR=${KALDI_DIR}/../OpenBLAS +mkdir -p ${OPENBLAS_DIR}/install +./configure --shared --use-cuda=no --static-math --mathlib=OPENBLAS --openblas-root=${OPENBLAS_DIR}/install make clean -j && make depend -j && make -j4 popd +popd + echo "Done installing Kaldi." diff --git a/utils/__init__.py b/utils/__init__.py old mode 100644 new mode 100755 diff --git a/utils/apply-cmvn.py b/utils/apply-cmvn.py new file mode 100755 index 00000000..f80053fb --- /dev/null +++ b/utils/apply-cmvn.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +import argparse +import logging +from distutils.util import strtobool + +import kaldiio +import numpy + +from deepspeech.transform.cmvn import CMVN +from deepspeech.utils.cli_readers import file_reader_helper +from deepspeech.utils.cli_utils import get_commandline_args +from deepspeech.utils.cli_utils import is_scipy_wav_style +from deepspeech.utils.cli_writers import file_writer_helper + + +def get_parser(): + parser = argparse.ArgumentParser( + description="apply mean-variance normalization to files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + + parser.add_argument( + "--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--in-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for the rspecifier. " + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--stats-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "npy"], + help="Specify the file format for the rspecifier. 
" + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--out-filetype", + type=str, + default="mat", + choices=["mat", "hdf5"], + help="Specify the file format for the wspecifier. " + '"mat" is the matrix format in kaldi', ) + + parser.add_argument( + "--norm-means", + type=strtobool, + default=True, + help="Do variance normalization or not.", ) + parser.add_argument( + "--norm-vars", + type=strtobool, + default=False, + help="Do variance normalization or not.", ) + parser.add_argument( + "--reverse", + type=strtobool, + default=False, + help="Do reverse mode or not") + parser.add_argument( + "--spk2utt", + type=str, + help="A text file of speaker to utterance-list map. " + "(Don't give rspecifier format, such as " + '"ark:spk2utt")', ) + parser.add_argument( + "--utt2spk", + type=str, + help="A text file of utterance to speaker map. " + "(Don't give rspecifier format, such as " + '"ark:utt2spk")', ) + parser.add_argument( + "--write-num-frames", + type=str, + help="Specify wspecifer for utt2num_frames") + parser.add_argument( + "--compress", + type=strtobool, + default=False, + help="Save in compressed format") + parser.add_argument( + "--compression-method", + type=int, + default=2, + help="Specify the method(if mat) or " + "gzip-level(if hdf5)", ) + parser.add_argument( + "stats_rspecifier_or_rxfilename", + help="Input stats. e.g. ark:stats.ark or stats.mat", ) + parser.add_argument( + "rspecifier", type=str, help="Read specifier id. e.g. ark:some.ark") + parser.add_argument( + "wspecifier", type=str, help="Write specifier id. e.g. ark:some.ark") + return parser + + +def main(): + args = get_parser().parse_args() + + # logging info + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + if args.verbose > 0: + logging.basicConfig(level=logging.INFO, format=logfmt) + else: + logging.basicConfig(level=logging.WARN, format=logfmt) + logging.info(get_commandline_args()) + + if ":" in args.stats_rspecifier_or_rxfilename: + is_rspcifier = True + if args.stats_filetype == "npy": + stats_filetype = "hdf5" + else: + stats_filetype = args.stats_filetype + + stats_dict = dict( + file_reader_helper(args.stats_rspecifier_or_rxfilename, + stats_filetype)) + else: + is_rspcifier = False + if args.stats_filetype == "mat": + stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename) + else: + stats = numpy.load(args.stats_rspecifier_or_rxfilename) + stats_dict = {None: stats} + + cmvn = CMVN( + stats=stats_dict, + norm_means=args.norm_means, + norm_vars=args.norm_vars, + utt2spk=args.utt2spk, + spk2utt=args.spk2utt, + reverse=args.reverse, ) + + with file_writer_helper( + args.wspecifier, + filetype=args.out_filetype, + write_num_frames=args.write_num_frames, + compress=args.compress, + compression_method=args.compression_method, ) as writer: + for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype): + if is_scipy_wav_style(mat): + # If data is sound file, then got as Tuple[int, ndarray] + rate, mat = mat + mat = cmvn(mat, utt if is_rspcifier else None) + writer[utt] = mat + + +if __name__ == "__main__": + main() diff --git a/utils/avg_model.py b/utils/avg_model.py index 7c05ec78..6ee16408 100755 --- a/utils/avg_model.py +++ b/utils/avg_model.py @@ -47,8 +47,10 @@ def main(args): beat_val_scores = sorted_val_scores[:args.num, 1] selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) + avg_val_score = np.mean(beat_val_scores) print("selected val scores = " + str(beat_val_scores)) print("selected epochs = " + str(selected_epochs)) + 
print("averaged val score = " + str(avg_val_score)) path_list = [ args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) @@ -80,7 +82,7 @@ def main(args): data = json.dumps({ "mode": 'val_best' if args.val_best else 'latest', "avg_ckpt": args.dst_model, - "val_loss_mean": np.mean(beat_val_scores), + "val_loss_mean": avg_val_score, "ckpts": path_list, "epochs": selected_epochs.tolist(), "val_losses": beat_val_scores.tolist(), diff --git a/utils/caculate_rtf.py b/utils/caculate_rtf.py new file mode 100755 index 00000000..fcc155ed --- /dev/null +++ b/utils/caculate_rtf.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2021 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +import argparse +import codecs +import glob +import os + +from dateutil import parser + + +def get_parser(): + parser = argparse.ArgumentParser( + description="calculate real time factor (RTF)") + parser.add_argument( + "--log-dir", + type=str, + default=None, + help="path to logging directory", ) + return parser + + +def main(): + + args = get_parser().parse_args() + + audio_sec = 0 + decode_sec = 0 + n_utt = 0 + + audio_durations = [] + start_times = [] + end_times = [] + for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")): + with codecs.open(x, "r", "utf-8") as f: + for line in f: + x = line.strip() + # 2021-10-25 08:22:04.052 | INFO | xxx:recog_v2:188 - feat: (1570, 83) + if "feat:" in x: + dur = int(x.split("(")[1].split(',')[0]) + audio_durations += [dur] + start_times += [parser.parse(x.split("|")[0])] + elif "total log probability:" in x: + end_times += [parser.parse(x.split("|")[0])] + assert len(audio_durations) == len(end_times), (len(audio_durations), + len(end_times), ) + assert len(start_times) == len(end_times), (len(start_times), + len(end_times)) + + audio_sec += sum(audio_durations) / 100 # [sec] + decode_sec += sum([(end - start).total_seconds() + for start, end in zip(start_times, end_times)]) + n_utt += len(audio_durations) + + print("Total audio duration: %.3f [sec]" % audio_sec) + print("Total decoding time: %.3f [sec]" % decode_sec) + rtf = decode_sec / audio_sec if audio_sec > 0 else 0 + print("RTF: %.3f" % rtf) + latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0 + print("Latency: %.3f [ms/sentence]" % latency) + + +if __name__ == "__main__": + main() diff --git a/utils/compute-cmvn-stats.py b/utils/compute-cmvn-stats.py new file mode 100755 index 00000000..706d8cd5 --- /dev/null +++ b/utils/compute-cmvn-stats.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +import argparse +import logging + +import kaldiio +import numpy as np + +from deepspeech.transform.transformation import Transformation +from deepspeech.utils.cli_readers import file_reader_helper +from deepspeech.utils.cli_utils import get_commandline_args +from deepspeech.utils.cli_utils import is_scipy_wav_style +from deepspeech.utils.cli_writers import file_writer_helper + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Compute cepstral mean and " + "variance normalization statistics" + "If wspecifier provided: per-utterance by default, " + "or per-speaker if" + "spk2utt option provided; if wxfilename: global", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument( + "--spk2utt", + type=str, + help="A text file of speaker to utterance-list map. 
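`caculate_rtf.py` above derives the audio duration from the `feat: (T, 83)` log lines (frame count divided by 100, i.e. a 10 ms frame shift) and the decoding time from the log timestamps. A worked example of the same arithmetic with made-up numbers, to show what the printed RTF and latency mean.

```python
frame_counts = [1570, 820, 2300]   # T from "feat: (T, 83)" lines
decode_seconds = [4.1, 2.0, 6.4]   # wall-clock decoding time per utterance

audio_sec = sum(frame_counts) / 100.0  # total speech duration [sec] at 100 frames/sec
decode_sec = sum(decode_seconds)       # total decoding time [sec]

rtf = decode_sec / audio_sec
latency_ms = decode_sec * 1000.0 / len(frame_counts)
print(f"RTF: {rtf:.3f}")                         # 12.5 / 46.9 -> 0.267
print(f"Latency: {latency_ms:.3f} [ms/sentence]")  # -> 4166.667
```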
" + "(Don't give rspecifier format, such as " + '"ark:utt2spk")', ) + parser.add_argument( + "--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--in-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for the rspecifier. " + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--out-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "npy"], + help="Specify the file format for the wspecifier. " + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", ) + parser.add_argument( + "rspecifier", + type=str, + help="Read specifier for feats. e.g. ark:some.ark") + parser.add_argument( + "wspecifier_or_wxfilename", + type=str, + help="Write specifier. e.g. ark:some.ark") + return parser + + +def main(): + args = get_parser().parse_args() + + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + if args.verbose > 0: + logging.basicConfig(level=logging.INFO, format=logfmt) + else: + logging.basicConfig(level=logging.WARN, format=logfmt) + logging.info(get_commandline_args()) + + is_wspecifier = ":" in args.wspecifier_or_wxfilename + + if is_wspecifier: + if args.spk2utt is not None: + logging.info("Performing as speaker CMVN mode") + utt2spk_dict = {} + with open(args.spk2utt) as f: + for line in f: + spk, utts = line.rstrip().split(None, 1) + for utt in utts.split(): + utt2spk_dict[utt] = spk + + def utt2spk(x): + return utt2spk_dict[x] + + else: + logging.info("Performing as utterance CMVN mode") + + def utt2spk(x): + return x + + if args.out_filetype == "npy": + logging.warning("--out-filetype npy is allowed only for " + "Global CMVN mode, changing to hdf5") + args.out_filetype = "hdf5" + + else: + logging.info("Performing as global CMVN mode") + if args.spk2utt is not None: + logging.warning("spk2utt is not used for global CMVN mode") + + def utt2spk(x): + return None + + if args.out_filetype == "hdf5": + logging.warning("--out-filetype hdf5 is not allowed for " + "Global CMVN mode, changing to npy") + args.out_filetype = "npy" + + if args.preprocess_conf is not None: + preprocessing = Transformation(args.preprocess_conf) + logging.info("Apply preprocessing: {}".format(preprocessing)) + else: + preprocessing = None + + # Calculate stats for each speaker + counts = {} + sum_feats = {} + square_sum_feats = {} + + idx = 0 + for idx, (utt, matrix) in enumerate( + file_reader_helper(args.rspecifier, args.in_filetype), 1): + if is_scipy_wav_style(matrix): + # If data is sound file, then got as Tuple[int, ndarray] + rate, matrix = matrix + if preprocessing is not None: + matrix = preprocessing(matrix, uttid_list=utt) + + spk = utt2spk(utt) + + # Init at the first seen of the spk + if spk not in counts: + counts[spk] = 0 + feat_shape = matrix.shape[1:] + # Accumulate in double precision + sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64) + square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64) + + counts[spk] += matrix.shape[0] + sum_feats[spk] += matrix.sum(axis=0) + square_sum_feats[spk] += (matrix**2).sum(axis=0) + logging.info("Processed {} utterances".format(idx)) + assert idx > 0, idx + + cmvn_stats = {} + for spk in counts: + feat_shape = sum_feats[spk].shape + cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:] + _cmvn_stats = np.empty(cmvn_shape, dtype=np.float64) + _cmvn_stats[0, :-1] = sum_feats[spk] + 
_cmvn_stats[1, :-1] = square_sum_feats[spk] + + _cmvn_stats[0, -1] = counts[spk] + _cmvn_stats[1, -1] = 0.0 + + # You can get the mean and std as following, + # >>> N = _cmvn_stats[0, -1] + # >>> mean = _cmvn_stats[0, :-1] / N + # >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2) + + cmvn_stats[spk] = _cmvn_stats + + # Per utterance or speaker CMVN + if is_wspecifier: + with file_writer_helper( + args.wspecifier_or_wxfilename, + filetype=args.out_filetype) as writer: + for spk, mat in cmvn_stats.items(): + writer[spk] = mat + + # Global CMVN + else: + matrix = cmvn_stats[None] + if args.out_filetype == "npy": + np.save(args.wspecifier_or_wxfilename, matrix) + elif args.out_filetype == "mat": + # Kaldi supports only matrix or vector + kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix) + else: + raise RuntimeError( + "Not supporting: --out-filetype {}".format(args.out_filetype)) + + +if __name__ == "__main__": + main() diff --git a/utils/compute_statistics.py b/utils/compute_statistics.py old mode 100644 new mode 100755 diff --git a/utils/copy-feats.py b/utils/copy-feats.py new file mode 100755 index 00000000..7d1b8589 --- /dev/null +++ b/utils/copy-feats.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +import argparse +import logging +from distutils.util import strtobool + +from deepspeech.transform.transformation import Transformation +from deepspeech.utils.cli_readers import file_reader_helper +from deepspeech.utils.cli_utils import get_commandline_args +from deepspeech.utils.cli_utils import is_scipy_wav_style +from deepspeech.utils.cli_writers import file_writer_helper + + +def get_parser(): + parser = argparse.ArgumentParser( + description="copy feature with preprocessing", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + + parser.add_argument( + "--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--in-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for the rspecifier. " + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--out-filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for the wspecifier. " + '"mat" is the matrix format in kaldi', ) + parser.add_argument( + "--write-num-frames", + type=str, + help="Specify wspecifer for utt2num_frames") + parser.add_argument( + "--compress", + type=strtobool, + default=False, + help="Save in compressed format") + parser.add_argument( + "--compression-method", + type=int, + default=2, + help="Specify the method(if mat) or " + "gzip-level(if hdf5)", ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", ) + parser.add_argument( + "rspecifier", + type=str, + help="Read specifier for feats. e.g. ark:some.ark") + parser.add_argument( + "wspecifier", type=str, help="Write specifier. e.g. 
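The accumulators above are packed into a Kaldi-style `(2, dim+1)` stats matrix: row 0 holds per-dimension sums with the frame count in the last column, row 1 holds sums of squares. The numpy sketch below just spells out the mean/std recovery described in the in-code comment and the normalization that `apply-cmvn.py` then performs; the repo does the real work through `deepspeech.transform.cmvn.CMVN`, so treat this as arithmetic only.

```python
import numpy as np


def cmvn_from_stats(stats, eps=1e-20):
    """stats: (2, dim+1) array -> per-dimension (mean, std)."""
    count = stats[0, -1]
    mean = stats[0, :-1] / count
    var = stats[1, :-1] / count - mean ** 2
    return mean, np.sqrt(np.maximum(var, eps))


def apply_cmvn(feats, stats, norm_means=True, norm_vars=False):
    """Normalize a (T, dim) feature matrix with precomputed CMVN stats."""
    mean, std = cmvn_from_stats(stats)
    out = np.asarray(feats, dtype=np.float64)
    if norm_means:
        out = out - mean
    if norm_vars:
        out = out / std
    return out
```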
ark:some.ark") + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + # logging info + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + if args.verbose > 0: + logging.basicConfig(level=logging.INFO, format=logfmt) + else: + logging.basicConfig(level=logging.WARN, format=logfmt) + logging.info(get_commandline_args()) + + if args.preprocess_conf is not None: + preprocessing = Transformation(args.preprocess_conf) + logging.info("Apply preprocessing: {}".format(preprocessing)) + else: + preprocessing = None + + with file_writer_helper( + args.wspecifier, + filetype=args.out_filetype, + write_num_frames=args.write_num_frames, + compress=args.compress, + compression_method=args.compression_method, ) as writer: + for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype): + if is_scipy_wav_style(mat): + # If data is sound file, then got as Tuple[int, ndarray] + rate, mat = mat + + if preprocessing is not None: + mat = preprocessing(mat, uttid_list=utt) + + # shape = (Time, Channel) + if args.out_filetype in ["sound.hdf5", "sound"]: + # Write Tuple[int, numpy.ndarray] (scipy style) + writer[utt] = (rate, mat) + else: + writer[utt] = mat + + +if __name__ == "__main__": + main() diff --git a/utils/data2json.sh b/utils/data2json.sh new file mode 100755 index 00000000..25131437 --- /dev/null +++ b/utils/data2json.sh @@ -0,0 +1,170 @@ +#!/usr/bin/env bash + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +echo "$0 $*" >&2 # Print the command line for logging +. ./path.sh + +nj=1 +cmd=run.pl +nlsyms="" +lang="" +feat="" # feat.scp +oov="" +bpecode="" +allow_one_column=false +verbose=0 +trans_type=char +filetype="" +preprocess_conf="" +category="" +out="" # If omitted, write in stdout + +text="" +multilingual=false + +help_message=$(cat << EOF +Usage: $0 +e.g. $0 data/train data/lang_1char/train_units.txt +Options: + --nj # number of parallel jobs + --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs. + --feat # feat.scp or feat1.scp,feat2.scp,... + --oov # Default: + --out # If omitted, write in stdout + --filetype # Specify the format of feats file + --preprocess-conf # Apply preprocess to feats when creating shape.scp + --verbose # Default: 0 +EOF +) +. utils/parse_options.sh + +if [ $# != 2 ]; then + echo "${help_message}" 1>&2 + exit 1; +fi + +set -euo pipefail + +dir=$1 +dic=$2 +tmpdir=$(mktemp -d ${dir}/tmp-XXXXX) +trap 'rm -rf ${tmpdir}' EXIT + +if [ -z ${text} ]; then + text=${dir}/text +fi + +# 1. Create scp files for inputs +# These are not necessary for decoding mode, and make it as an option +input= +if [ -n "${feat}" ]; then + _feat_scps=$(echo "${feat}" | tr ',' ' ' ) + read -r -a feat_scps <<< $_feat_scps + num_feats=${#feat_scps[@]} + + for (( i=1; i<=num_feats; i++ )); do + feat=${feat_scps[$((i-1))]} + mkdir -p ${tmpdir}/input_${i} + input+="input_${i} " + cat ${feat} > ${tmpdir}/input_${i}/feat.scp + + # Dump in the "legacy" style JSON format + if [ -n "${filetype}" ]; then + awk -v filetype=${filetype} '{print $1 " " filetype}' ${feat} \ + > ${tmpdir}/input_${i}/filetype.scp + fi + + feat_to_shape.sh --cmd "${cmd}" --nj ${nj} \ + --filetype "${filetype}" \ + --preprocess-conf "${preprocess_conf}" \ + --verbose ${verbose} ${feat} ${tmpdir}/input_${i}/shape.scp + done +fi + +# 2. 
Create scp files for outputs +mkdir -p ${tmpdir}/output +if [ -n "${bpecode}" ]; then + if [ ${multilingual} = true ]; then + # remove a space before the language ID + paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \ + | spm_encode --model=${bpecode} --output_format=piece | cut -f 2- -d" ") \ + > ${tmpdir}/output/token.scp + else + paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \ + | spm_encode --model=${bpecode} --output_format=piece) \ + > ${tmpdir}/output/token.scp + fi +elif [ -n "${nlsyms}" ]; then + text2token.py -s 1 -n 1 -l ${nlsyms} ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp +else + text2token.py -s 1 -n 1 ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp +fi +< ${tmpdir}/output/token.scp utils/sym2int.pl --map-oov ${oov} -f 2- ${dic} > ${tmpdir}/output/tokenid.scp +# +2 comes from CTC blank and EOS +vocsize=$(tail -n 1 ${dic} | awk '{print $2}') +odim=$(echo "$vocsize + 2" | bc) +< ${tmpdir}/output/tokenid.scp awk -v odim=${odim} '{print $1 " " NF-1 "," odim}' > ${tmpdir}/output/shape.scp + +cat ${text} > ${tmpdir}/output/text.scp + + +# 3. Create scp files for the others +mkdir -p ${tmpdir}/other +if [ ${multilingual} == true ]; then + awk '{ + n = split($1,S,"[-]"); + lang=S[n]; + print $1 " " lang + }' ${text} > ${tmpdir}/other/lang.scp +elif [ -n "${lang}" ]; then + awk -v lang=${lang} '{print $1 " " lang}' ${text} > ${tmpdir}/other/lang.scp +fi + +if [ -n "${category}" ]; then + awk -v category=${category} '{print $1 " " category}' ${dir}/text \ + > ${tmpdir}/other/category.scp +fi +cat ${dir}/utt2spk > ${tmpdir}/other/utt2spk.scp + +# 4. Merge scp files into a JSON file +opts="" +if [ -n "${feat}" ]; then + intypes="${input} output other" +else + intypes="output other" +fi +for intype in ${intypes}; do + if [ -z "$(find "${tmpdir}/${intype}" -name "*.scp")" ]; then + continue + fi + + if [ ${intype} != other ]; then + opts+="--${intype%_*}-scps " + else + opts+="--scps " + fi + + for x in "${tmpdir}/${intype}"/*.scp; do + k=$(basename ${x} .scp) + if [ ${k} = shape ]; then + opts+="shape:${x}:shape " + else + opts+="${k}:${x} " + fi + done +done + +if ${allow_one_column}; then + opts+="--allow-one-column true " +else + opts+="--allow-one-column false " +fi + +if [ -n "${out}" ]; then + opts+="-O ${out}" +fi +merge_scp2json.py --verbose ${verbose} ${opts} + +rm -fr ${tmpdir} diff --git a/utils/dump.sh b/utils/dump.sh new file mode 100755 index 00000000..1f312b3a --- /dev/null +++ b/utils/dump.sh @@ -0,0 +1,95 @@ +#!/usr/bin/env bash + +# Copyright 2017 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +echo "$0 $*" # Print the command line for logging +. ./path.sh + +cmd=run.pl +do_delta=false +nj=1 +verbose=0 +compress=true +write_utt2num_frames=true +filetype='mat' # mat or hdf5 +help_message="Usage: $0 " + +. utils/parse_options.sh + +scp=$1 +cvmnark=$2 +logdir=$3 +dumpdir=$4 + +if [ $# != 4 ]; then + echo "${help_message}" + exit 1; +fi + +set -euo pipefail + +mkdir -p ${logdir} +mkdir -p ${dumpdir} + +dumpdir=$(perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD}) + +for n in $(seq ${nj}); do + # the next command does nothing unless $dumpdir/storage/ exists, see + # utils/create_data_link.pl for more info. 
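`data2json.sh` above maps each BPE token sequence to integer ids with the units dictionary and records the output shape as `<length>,<odim>`, where `odim` is the largest dictionary id plus 2 (CTC blank and eos). A hedged Python rendering of that step; the dictionary format (`<token> <id>` per line) is the one written in stage 2, and mapping unknown tokens to `<unk>` here is an assumption standing in for `sym2int.pl --map-oov`.

```python
def load_units(dict_path):
    units = {}
    with open(dict_path, encoding="utf-8") as f:
        for line in f:
            tok, idx = line.split()
            units[tok] = int(idx)
    return units


def output_entry(tokens, units, oov="<unk>"):
    """Build the token/tokenid/shape fields of one output record."""
    token_ids = [units.get(t, units[oov]) for t in tokens]
    odim = max(units.values()) + 2  # + CTC blank and eos
    return {
        "token": " ".join(tokens),
        "tokenid": " ".join(map(str, token_ids)),
        "shape": [len(token_ids), odim],
    }


units = load_units("data/lang_char/train_960_unigram5000_units.txt")
print(output_entry("▁HELLO ▁WORLD".split(), units))
```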
+ utils/create_data_link.pl ${dumpdir}/feats.${n}.ark +done + +if ${write_utt2num_frames}; then + write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB" +else + write_num_frames_opt= +fi + +# split scp file +split_scps="" +for n in $(seq ${nj}); do + split_scps="$split_scps $logdir/feats.$n.scp" +done + +utils/split_scp.pl ${scp} ${split_scps} || exit 1; + +# dump features +if ${do_delta}; then + ${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \ + apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \ + add-deltas ark:- ark:- \| \ + copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \ + --compress=${compress} --compression-method=2 ${write_num_frames_opt} \ + ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \ + || exit 1 +else + ${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \ + apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \ + copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \ + --compress=${compress} --compression-method=2 ${write_num_frames_opt} \ + ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \ + || exit 1 +fi + +# concatenate scp files +for n in $(seq ${nj}); do + cat ${dumpdir}/feats.${n}.scp || exit 1; +done > ${dumpdir}/feats.scp || exit 1 + +if ${write_utt2num_frames}; then + for n in $(seq ${nj}); do + cat ${dumpdir}/utt2num_frames.${n} || exit 1; + done > ${dumpdir}/utt2num_frames || exit 1 + rm ${dumpdir}/utt2num_frames.* 2>/dev/null +fi + +# Write the filetype, this will be used for data2json.sh +echo ${filetype} > ${dumpdir}/filetype + + +# remove temp scps +rm ${logdir}/feats.*.scp 2>/dev/null +if [ ${verbose} -eq 1 ]; then + echo "Succeeded dumping features for training" +fi diff --git a/utils/feat-to-shape.py b/utils/feat-to-shape.py new file mode 100755 index 00000000..911bf5cf --- /dev/null +++ b/utils/feat-to-shape.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys + +from deepspeech.transform.transformation import Transformation +from deepspeech.utils.cli_readers import file_reader_helper +from deepspeech.utils.cli_utils import get_commandline_args +from deepspeech.utils.cli_utils import is_scipy_wav_style + + +def get_parser(): + parser = argparse.ArgumentParser( + description="convert feature to its shape", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--filetype", + type=str, + default="mat", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for the rspecifier. " + '"mat" is the matrix format in kaldi', + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark" + ) + parser.add_argument( + "out", + nargs="?", + type=argparse.FileType("w"), + default=sys.stdout, + help="The output filename. 
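`dump.sh` above pipes `apply-cmvn` into `copy-feats.py` and writes paired `feats.ark`/`feats.scp` archives. The sketch below does the same read-normalize-write loop directly with `kaldiio` (already a project dependency); the stats file name and output paths are placeholders, and reading `cmvn.ark` as a single matrix is an assumption.

```python
import numpy as np
from kaldiio import ReadHelper, WriteHelper, load_mat

# Global CMVN stats from stage 1 (shape (2, dim+1)), assumed loadable as one matrix.
stats = load_mat("data/train_sp/cmvn.ark")
count = stats[0, -1]
mean = stats[0, :-1] / count
std = np.sqrt(stats[1, :-1] / count - mean ** 2)

with ReadHelper("scp:data/train_sp/feats.scp") as reader, \
        WriteHelper("ark,scp:dump/train_sp/feats.ark,dump/train_sp/feats.scp") as writer:
    for utt, feats in reader:
        # Mean/variance normalization, matching --norm-vars=true in dump.sh.
        writer(utt, (feats - mean) / std)
```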
" "If omitted, then output to sys.stdout", + ) + return parser + + +def main(): + parser = get_parser() + args = parser.parse_args() + + # logging info + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + if args.verbose > 0: + logging.basicConfig(level=logging.INFO, format=logfmt) + else: + logging.basicConfig(level=logging.WARN, format=logfmt) + logging.info(get_commandline_args()) + + if args.preprocess_conf is not None: + preprocessing = Transformation(args.preprocess_conf) + logging.info("Apply preprocessing: {}".format(preprocessing)) + else: + preprocessing = None + + # There are no necessary for matrix without preprocessing, + # so change to file_reader_helper to return shape. + # This make sense only with filetype="hdf5". + for utt, mat in file_reader_helper( + args.rspecifier, args.filetype, return_shape=preprocessing is None + ): + if preprocessing is not None: + if is_scipy_wav_style(mat): + # If data is sound file, then got as Tuple[int, ndarray] + rate, mat = mat + mat = preprocessing(mat, uttid_list=utt) + shape_str = ",".join(map(str, mat.shape)) + else: + if len(mat) == 2 and isinstance(mat[1], tuple): + # If data is sound file, Tuple[int, Tuple[int, ...]] + rate, mat = mat + shape_str = ",".join(map(str, mat)) + args.out.write("{} {}\n".format(utt, shape_str)) + + +if __name__ == "__main__": + main() diff --git a/utils/feat_to_shape.sh b/utils/feat_to_shape.sh new file mode 100755 index 00000000..7f4668c4 --- /dev/null +++ b/utils/feat_to_shape.sh @@ -0,0 +1,72 @@ +#!/usr/bin/env bash + +# Begin configuration section. +nj=4 +cmd=run.pl +verbose=0 +filetype="" +preprocess_conf="" +# End configuration section. + +help_message=$(cat << EOF +Usage: $0 [options] [] +e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log +Options: + --nj # number of parallel jobs + --cmd (utils/run.pl|utils/queue.pl ) # how to run jobs. + --filetype # Specify the format of feats file + --preprocess-conf # Apply preprocess to feats when creating shape.scp + --verbose # Default: 0 +EOF +) + +echo "$0 $*" 1>&2 # Print the command line for logging + +. parse_options.sh || exit 1; + +if [ $# -lt 2 ] || [ $# -gt 3 ]; then + echo "${help_message}" 1>&2 + exit 1; +fi + +set -euo pipefail + +scp=$1 +outscp=$2 +data=$(dirname ${scp}) +if [ $# -eq 3 ]; then + logdir=$3 +else + logdir=${data}/log +fi +mkdir -p ${logdir} + +nj=$((nj<$(<"${scp}" wc -l)?nj:$(<"${scp}" wc -l))) +split_scps="" +for n in $(seq ${nj}); do + split_scps="${split_scps} ${logdir}/feats.${n}.scp" +done + +utils/split_scp.pl ${scp} ${split_scps} + +if [ -n "${preprocess_conf}" ]; then + preprocess_opt="--preprocess-conf ${preprocess_conf}" +else + preprocess_opt="" +fi +if [ -n "${filetype}" ]; then + filetype_opt="--filetype ${filetype}" +else + filetype_opt="" +fi + +${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \ + feat-to-shape.py --verbose ${verbose} ${preprocess_opt} ${filetype_opt} \ + scp:${logdir}/feats.JOB.scp ${logdir}/shape.JOB.scp + +# concatenate the .scp files together. 
+for n in $(seq ${nj}); do + cat ${logdir}/shape.${n}.scp +done > ${outscp} + +rm -f ${logdir}/feats.*.scp 2>/dev/null diff --git a/utils/gen_duration_from_textgrid.py b/utils/gen_duration_from_textgrid.py old mode 100644 new mode 100755 diff --git a/utils/merge_scp2json.py b/utils/merge_scp2json.py new file mode 100755 index 00000000..b724a7dd --- /dev/null +++ b/utils/merge_scp2json.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +import argparse +import codecs +import json +import logging +import sys +from distutils.util import strtobool +from io import open + +from deepspeech.utils.cli_utils import get_commandline_args + +PY2 = sys.version_info[0] == 2 +sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer) +sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer) + + +# Special types: +def shape(x): + """Change str to List[int] + + >>> shape('3,5') + [3, 5] + >>> shape(' [3, 5] ') + [3, 5] + + """ + + # x: ' [3, 5] ' -> '3, 5' + x = x.strip() + if x[0] == "[": + x = x[1:] + if x[-1] == "]": + x = x[:-1] + + return list(map(int, x.split(","))) + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Given each file paths with such format as " + "::. type> can be omitted and the default " + 'is "str". e.g. {} ' + "--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape " + "--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape " + "--output-scps text:data/text shape:data/utt2text_shape:shape " + "--scps utt2spk:data/utt2spk".format(sys.argv[0]), + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument( + "--input-scps", + type=str, + nargs="*", + action="append", + default=[], + help="Json files for the inputs", ) + parser.add_argument( + "--output-scps", + type=str, + nargs="*", + action="append", + default=[], + help="Json files for the outputs", ) + parser.add_argument( + "--scps", + type=str, + nargs="+", + default=[], + help="The json files except for the input and outputs", ) + parser.add_argument( + "--verbose", "-V", default=1, type=int, help="Verbose option") + parser.add_argument( + "--allow-one-column", + type=strtobool, + default=False, + help="Allow one column in input scp files. " + "In this case, the value will be empty string.", ) + parser.add_argument( + "--out", + "-O", + type=str, + help="The output filename. " + "If omitted, then output to sys.stdout", ) + return parser + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + args.scps = [args.scps] + + # logging info + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + if args.verbose > 0: + logging.basicConfig(level=logging.INFO, format=logfmt) + else: + logging.basicConfig(level=logging.WARN, format=logfmt) + logging.info(get_commandline_args()) + + # List[List[Tuple[str, str, Callable[[str], Any], str, str]]] + input_infos = [] + output_infos = [] + infos = [] + for lis_list, key_scps_list in [ + (input_infos, args.input_scps), + (output_infos, args.output_scps), + (infos, args.scps), + ]: + for key_scps in key_scps_list: + lis = [] + for key_scp in key_scps: + sps = key_scp.split(":") + if len(sps) == 2: + key, scp = sps + type_func = None + type_func_str = "none" + elif len(sps) == 3: + key, scp, type_func_str = sps + fail = False + + try: + # type_func: Callable[[str], Any] + # e.g. 
type_func_str = "int" -> type_func = int + type_func = eval(type_func_str) + except Exception: + raise RuntimeError( + "Unknown type: {}".format(type_func_str)) + + if not callable(type_func): + raise RuntimeError( + "Unknown type: {}".format(type_func_str)) + + else: + raise RuntimeError( + "Format : " + "or :: " + "e.g. feat:data/feat.scp " + "or shape:data/feat.scp:shape: {}".format(key_scp)) + + for item in lis: + if key == item[0]: + raise RuntimeError('The key "{}" is duplicated: {} {}'. + format(key, item[3], key_scp)) + + lis.append((key, scp, type_func, key_scp, type_func_str)) + lis_list.append(lis) + + # Open scp files + input_fscps = [[open(i[1], "r", encoding="utf-8") for i in il] + for il in input_infos] + output_fscps = [[open(i[1], "r", encoding="utf-8") for i in il] + for il in output_infos] + fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos] + + # Note(kamo): What is done here? + # The final goal is creating a JSON file such as. + # { + # "utts": { + # "sample_id1": {(omitted)}, + # "sample_id2": {(omitted)}, + # .... + # } + # } + # + # To reduce memory usage, reading the input text files for each lines + # and writing JSON elements per samples. + if args.out is None: + out = sys.stdout + else: + out = open(args.out, "w", encoding="utf-8") + out.write('{\n "utts": {\n') + nutt = 0 + while True: + nutt += 1 + # List[List[str]] + input_lines = [[f.readline() for f in fl] for fl in input_fscps] + output_lines = [[f.readline() for f in fl] for fl in output_fscps] + lines = [[f.readline() for f in fl] for fl in fscps] + + # Get the first line + concat = sum(input_lines + output_lines + lines, []) + if len(concat) == 0: + break + first = concat[0] + + # Sanity check: Must be sorted by the first column and have same keys + count = 0 + for ls_list in (input_lines, output_lines, lines): + for ls in ls_list: + for line in ls: + if line == "" or first == "": + if line != first: + concat = sum(input_infos + output_infos + infos, []) + raise RuntimeError("The number of lines mismatch " + 'between: "{}" and "{}"'.format( + concat[0][1], + concat[count][1])) + + elif line.split()[0] != first.split()[0]: + concat = sum(input_infos + output_infos + infos, []) + raise RuntimeError( + "The keys are mismatch at {}th line " + 'between "{}" and "{}":\n>>> {}\n>>> {}'.format( + nutt, + concat[0][1], + concat[count][1], + first.rstrip(), + line.rstrip(), )) + count += 1 + + # The end of file + if first == "": + if nutt != 1: + out.write("\n") + break + if nutt != 1: + out.write(",\n") + + entry = {} + for inout, _lines, _infos in [ + ("input", input_lines, input_infos), + ("output", output_lines, output_infos), + ("other", lines, infos), + ]: + + lis = [] + for idx, (line_list, info_list) in enumerate( + zip(_lines, _infos), 1): + if inout == "input": + d = {"name": "input{}".format(idx)} + elif inout == "output": + d = {"name": "target{}".format(idx)} + else: + d = {} + + # info_list: List[Tuple[str, str, Callable]] + # line_list: List[str] + for line, info in zip(line_list, info_list): + sps = line.split(None, 1) + if len(sps) < 2: + if not args.allow_one_column: + raise RuntimeError( + "Format error {}th line in {}: " + ' Expecting " ":\n>>> {}'.format( + nutt, info[1], line)) + uttid = sps[0] + value = "" + else: + uttid, value = sps + + key = info[0] + type_func = info[2] + value = value.rstrip() + + if type_func is not None: + try: + # type_func: Callable[[str], Any] + value = type_func(value) + except Exception: + logging.error( + '"{}" is an invalid function ' + 
"for the {} th line in {}: \n>>> {}".format( + info[4], nutt, info[1], line)) + raise + + d[key] = value + lis.append(d) + + if inout != "other": + entry[inout] = lis + else: + # If key == 'other'. only has the first item + entry.update(lis[0]) + + entry = json.dumps( + entry, + indent=4, + ensure_ascii=False, + sort_keys=True, + separators=(",", ": ")) + # Add indent + indent = " " * 2 + entry = ("\n" + indent).join(entry.split("\n")) + + uttid = first.split()[0] + out.write(' "{}": {}'.format(uttid, entry)) + + out.write(" }\n}\n") + + logging.info("{} entries in {}".format(nutt, out.name)) diff --git a/utils/reduce_data_dir.sh b/utils/reduce_data_dir.sh new file mode 100755 index 00000000..60c82a7c --- /dev/null +++ b/utils/reduce_data_dir.sh @@ -0,0 +1,59 @@ +#!/usr/bin/env bash + +# koried, 10/29/2012 + +# Reduce a data set based on a list of turn-ids + +help_message="usage: $0 srcdir turnlist destdir" + +if [ $1 == "--help" ]; then + echo "${help_message}" + exit 0; +fi + +if [ $# != 3 ]; then + echo "${help_message}" + exit 1; +fi + +srcdir=$1 +reclist=$2 +destdir=$3 + +if [ ! -f ${srcdir}/utt2spk ]; then +echo "$0: no such file $srcdir/utt2spk" +exit 1; +fi + +function do_filtering { +# assumes the utt2spk and spk2utt files already exist. + [ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp + [ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp + [ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text + [ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames + [ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender + [ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp + if [ -f ${srcdir}/segments ]; then + utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments + awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings. + # The next line would override the command above for wav.scp, which would be incorrect. + [ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp + [ -f ${srcdir}/reco2file_and_channel ] && \ + utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel + + # Filter the STM file for proper sclite scoring (this will also remove the comments lines) + [ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm + rm ${destdir}/reco + fi + srcutts=$(wc -l < ${srcdir}/utt2spk) + destutts=$(wc -l < ${destdir}/utt2spk) + echo "Reduced #utt from $srcutts to $destutts" +} + +mkdir -p ${destdir} + +# filter the utt2spk based on the set of recordings +utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk + +utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt +do_filtering; diff --git a/utils/remove_longshortdata.sh b/utils/remove_longshortdata.sh new file mode 100755 index 00000000..e0b9da09 --- /dev/null +++ b/utils/remove_longshortdata.sh @@ -0,0 +1,62 @@ +#!/usr/bin/env bash + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +. 
./path.sh
+
+maxframes=2000
+minframes=10
+maxchars=200
+minchars=0
+nlsyms=""
+no_feat=false
+trans_type=char
+
+help_message="usage: $0 olddatadir newdatadir"
+
+. utils/parse_options.sh || exit 1;
+
+if [ $# != 2 ]; then
+    echo "${help_message}"
+    exit 1;
+fi
+
+sdir=$1
+odir=$2
+mkdir -p ${odir}/tmp
+
+if [ ${no_feat} = true ]; then
+    # for machine translation
+    cut -d' ' -f 1 ${sdir}/text > ${odir}/tmp/reclist1
+else
+    echo "extract utterances having less than $maxframes and more than $minframes frames"
+    utils/data/get_utt2num_frames.sh ${sdir}
+    < ${sdir}/utt2num_frames awk -v maxframes="$maxframes" '{ if ($2 < maxframes) print }' \
+        | awk -v minframes="$minframes" '{ if ($2 > minframes) print }' \
+        | awk '{print $1}' > ${odir}/tmp/reclist1
+fi
+
+echo "extract utterances having less than $maxchars and more than $minchars characters"
+# counting number of chars. Use (NF - 1) instead of NF to exclude the utterance ID column
+if [ -z ${nlsyms} ]; then
+text2token.py -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
+    | awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
+    | awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
+    | awk '{print $1}' > ${odir}/tmp/reclist2
+else
+text2token.py -l ${nlsyms} -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
+    | awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
+    | awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
+    | awk '{print $1}' > ${odir}/tmp/reclist2
+fi
+
+# extract common lines
+comm -12 <(sort ${odir}/tmp/reclist1) <(sort ${odir}/tmp/reclist2) > ${odir}/tmp/reclist
+
+reduce_data_dir.sh ${sdir} ${odir}/tmp/reclist ${odir}
+utils/fix_data_dir.sh ${odir}
+
+oldnum=$(wc -l ${sdir}/feats.scp | awk '{print $1}')
+newnum=$(wc -l ${odir}/feats.scp | awk '{print $1}')
+echo "change from $oldnum to $newnum"
diff --git a/utils/text2token.py b/utils/text2token.py
new file mode 100755
index 00000000..4b25612e
--- /dev/null
+++ b/utils/text2token.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+import argparse
+import codecs
+import re
+import sys
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+    start_pos = None
+    end_pos = None
+    for pos in match_pos:
+        if pos[0] <= i < pos[1]:
+            start_pos = pos[0]
+            end_pos = pos[1]
+            break
+
+    return start_pos, end_pos
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="convert raw text to tokenized text",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
+    parser.add_argument(
+        "--nchar",
+        "-n",
+        default=1,
+        type=int,
+        help="number of characters to split, i.e., \
+            aabb -> a a b b with -n 1 and aa bb with -n 2", )
+    parser.add_argument(
+        "--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
+    parser.add_argument(
+        "--space", default="", type=str, help="space symbol")
+    parser.add_argument(
+        "--non-lang-syms",
+        "-l",
+        default=None,
+        type=str,
+        help="list of non-linguistic symbols, e.g., etc.", )
+    parser.add_argument(
+        "text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "--trans_type",
+        "-t",
+        type=str,
+        default="char",
+        choices=["char", "phn"],
+        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
+                     If trans_type is char,
+                     read from SI1279.WRD file -> "bricks are an alternative"
+                     Else if trans_type is phn,
+                     read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
+                     sil t er n ih sil t ih v sil" """, )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    rs = []
+    if args.non_lang_syms is not None:
+        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+            nls = [x.rstrip() for x in f.readlines()]
+            rs = [re.compile(re.escape(x)) for x in nls]
+
+    if args.text:
+        f = codecs.open(args.text, encoding="utf-8")
+    else:
+        f = codecs.getreader("utf-8")(sys.stdin
+                                      if is_python2 else sys.stdin.buffer)
+
+    sys.stdout = codecs.getwriter("utf-8")(sys.stdout
+                                           if is_python2 else sys.stdout.buffer)
+    line = f.readline()
+    n = args.nchar
+    while line:
+        x = line.split()
+        print(" ".join(x[:args.skip_ncols]), end=" ")
+        a = " ".join(x[args.skip_ncols:])
+
+        # get all matched positions
+        match_pos = []
+        for r in rs:
+            i = 0
+            while i >= 0:
+                m = r.search(a, i)
+                if m:
+                    match_pos.append([m.start(), m.end()])
+                    i = m.end()
+                else:
+                    break
+
+        if args.trans_type == "phn":
+            a = a.split(" ")
+        else:
+            if len(match_pos) > 0:
+                chars = []
+                i = 0
+                while i < len(a):
+                    start_pos, end_pos = exist_or_not(i, match_pos)
+                    if start_pos is not None:
+                        chars.append(a[start_pos:end_pos])
+                        i = end_pos
+                    else:
+                        chars.append(a[i])
+                        i += 1
+                a = chars
+
+        a = [a[j:j + n] for j in range(0, len(a), n)]
+
+        a_flat = []
+        for z in a:
+            a_flat.append("".join(z))
+
+        a_chars = [z.replace(" ", args.space) for z in a_flat]
+        if args.trans_type == "phn":
+            a_chars = [z.replace("sil", args.space) for z in a_chars]
+        print(" ".join(a_chars))
+        line = f.readline()
+
+
+if __name__ == "__main__":
+    main()
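
A minimal usage sketch of how the new shape/JSON utilities chain together. It is not part of the patch: it assumes a Kaldi-style data directory (data/train with feats.scp, text and utt2spk, all sorted by utterance id) and that utils/ plus run.pl/parse_options.sh are reachable on the PATH (e.g. via path.sh); directory names, the --nj value and output file names are illustrative only.

# 1) Derive "<uttid> <frames>,<dim>" entries from the existing feature scp.
#    feat_to_shape.sh splits feats.scp into --nj pieces, runs feat-to-shape.py
#    per job, and concatenates the per-job shape scps into shape.scp.
utils/feat_to_shape.sh --nj 8 data/train/feats.scp data/train/shape.scp data/train/log

# 2) Merge the scp-style files into a single data JSON.
#    The ":shape" type suffix makes merge_scp2json.py parse "283,83" into [283, 83];
#    every file passed here must contain exactly the same utterance ids in the same order,
#    otherwise the script aborts with a key-mismatch error.
utils/merge_scp2json.py \
    --input-scps feat:data/train/feats.scp shape:data/train/shape.scp:shape \
    --output-scps text:data/train/text \
    --scps utt2spk:data/train/utt2spk \
    --out data/train/data.json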
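
feat-to-shape.py can also be invoked directly (this is what feat_to_shape.sh runs per job), for example to make the reported shapes reflect a feature-level preprocessing config. A sketch, assuming the deepspeech package is importable and conf/preprocess.json is a placeholder config name:

# Without --preprocess-conf only the stored shape is needed (a cheap lookup for hdf5);
# with it, each matrix is loaded and passed through the Transformation pipeline first.
utils/feat-to-shape.py --filetype mat \
    --preprocess-conf conf/preprocess.json \
    scp:data/train/feats.scp data/train/shape.scp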
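
For the text side, a sketch of character tokenization and outlier removal. The paths data/train, data/train_trim and data/lang/non_lang_syms.txt are assumed names, and remove_longshortdata.sh expects to run from a recipe directory that provides ./path.sh, utils/data/get_utt2num_frames.sh, utils/fix_data_dir.sh and a populated feats.scp.

# One token per character: -s 1 keeps the utterance-id column intact, -n 1 splits the
# rest character by character, and symbols listed via -l are protected from splitting.
utils/text2token.py -s 1 -n 1 -l data/lang/non_lang_syms.txt \
    data/train/text > data/train/text.char

# Drop utterances outside the frame/character limits; the options override the script's
# defaults (maxframes=2000, minframes=10, maxchars=200, minchars=0) via parse_options.sh.
utils/remove_longshortdata.sh --maxframes 3000 --maxchars 400 \
    data/train data/train_trim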