more utils to support kaldi/espnet data preocess

pull/931/head
Hui Zhang 3 years ago
parent ed19e243de
commit 871fc5b70d

1
.gitignore vendored

@ -24,5 +24,6 @@ tools/montreal-forced-aligner/
tools/Montreal-Forced-Aligner/ tools/Montreal-Forced-Aligner/
tools/sctk tools/sctk
tools/sctk-20159b5/ tools/sctk-20159b5/
tools/kaldi
*output/ *output/

@ -318,6 +318,18 @@ class CTCPrefixScore():
r[0, 0] = xs[0] r[0, 0] = xs[0]
r[0, 1] = self.logzero r[0, 1] = self.logzero
else: else:
# Although the code does not exactly follow Algorithm 2,
# we don't have to change it because we can assume
# r_t(h)=0 for t < |h| in CTC forward computation
# (Note: we assume here that index t starts with 0).
# The purpose of this difference is to reduce the number of for-loops.
# https://github.com/espnet/espnet/pull/3655
# where we start to accumulate r_t(h) from t=|h|
# and iterate r_t(h) = (r_{t-1}(h) + ...) to T-1,
# avoiding accumulating zeros for t=1~|h|-1.
# Thus, we need to set r_{|h|-1}(h) = 0,
# i.e., r[output_length-1] = logzero, for initialization.
# This is just for reducing the computation.
r[output_length - 1] = self.logzero r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label # prepare forward probabilities for the last label

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the data augmentation pipeline.""" """Contains the data augmentation pipeline."""
import json import json
import os
from collections.abc import Sequence from collections.abc import Sequence
from inspect import signature from inspect import signature
from pprint import pformat from pprint import pformat
@ -90,9 +91,8 @@ class AugmentationPipeline():
effect. effect.
Params: Params:
augmentation_config(str): Augmentation configuration in json string. preprocess_conf(str): Augmentation configuration in `json file` or `json string`.
random_seed(int): Random seed. random_seed(int): Random seed.
train(bool): whether is train mode.
Raises: Raises:
ValueError: If the augmentation json config is in incorrect format". ValueError: If the augmentation json config is in incorrect format".
@ -100,11 +100,18 @@ class AugmentationPipeline():
SPEC_TYPES = {'specaug'} SPEC_TYPES = {'specaug'}
def __init__(self, augmentation_config: str, random_seed: int=0): def __init__(self, preprocess_conf: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed) self._rng = np.random.RandomState(random_seed)
self.conf = {'mode': 'sequential', 'process': []} self.conf = {'mode': 'sequential', 'process': []}
if augmentation_config: if preprocess_conf:
process = json.loads(augmentation_config) if os.path.isfile(preprocess_conf):
# json file
with open(preprocess_conf, 'r') as fin:
json_string = fin.read()
else:
# json string
json_string = preprocess_conf
process = json.loads(json_string)
self.conf['process'] += process self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all') self._augmentors, self._rates = self._parse_pipeline_from('all')

@ -105,7 +105,7 @@ class SpeechCollatorBase():
self._local_data = TarLocalData(tar2info={}, tar2object={}) self._local_data = TarLocalData(tar2info={}, tar2object={})
self.augmentation = AugmentationPipeline( self.augmentation = AugmentationPipeline(
augmentation_config=aug_file.read(), random_seed=random_seed) preprocess_conf=aug_file.read(), random_seed=random_seed)
self._normalizer = FeatureNormalizer( self._normalizer = FeatureNormalizer(
mean_std_filepath) if mean_std_filepath else None mean_std_filepath) if mean_std_filepath else None

@ -17,14 +17,13 @@ import kaldiio
import numpy as np import numpy as np
import soundfile import soundfile
from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline as Transformation
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["LoadInputsAndTargets"] __all__ = ["LoadInputsAndTargets"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class LoadInputsAndTargets(): class LoadInputsAndTargets():
"""Create a mini-batch from a list of dicts """Create a mini-batch from a list of dicts
@ -66,8 +65,7 @@ class LoadInputsAndTargets():
raise ValueError("Only asr are allowed: mode={}".format(mode)) raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None: if preprocess_conf is not None:
with open(preprocess_conf, 'r') as fin: self.preprocessing = Transformation(preprocess_conf)
self.preprocessing = AugmentationPipeline(fin.read())
logger.warning( logger.warning(
"[Experimental feature] Some preprocessing will be done " "[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format( "for the mini-batch creation using {}".format(

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
from typing import Any from typing import Any
from typing import List from typing import List
from typing import Tuple from typing import Tuple
@ -26,6 +25,9 @@ from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__( def __init__(
@ -74,10 +76,10 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
self.decoder = nn.Linear(att_unit, n_vocab) self.decoder = nn.Linear(att_unit, n_vocab)
logging.info("Tie weights set to {}".format(tie_weights)) logger.info("Tie weights set to {}".format(tie_weights))
logging.info("Dropout set to {}".format(dropout_rate)) logger.info("Dropout set to {}".format(dropout_rate))
logging.info("Emb Dropout set to {}".format(emb_dropout_rate)) logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
logging.info("Att Dropout set to {}".format(att_dropout_rate)) logger.info("Att Dropout set to {}".format(att_dropout_rate))
if tie_weights: if tie_weights:
assert ( assert (

@ -23,17 +23,39 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = [ __all__ = [
"NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding" "PositionalEncodingInterface", "NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
] ]
class PositionalEncodingInterface:
class NoPositionalEncoding(nn.Layer): def forward(self, x:paddle.Tensor, offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Compute positional encoding.
Args:
x (paddle.Tensor): Input tensor (batch, time, `*`).
Returns:
paddle.Tensor: Encoded tensor (batch, time, `*`).
paddle.Tensor: Positional embedding tensor (1, time, `*`).
"""
raise NotImplementedError("forward method is not implemented")
def position_encoding(self, offset:int, size:int) -> paddle.Tensor:
""" For getting encoding in a streaming fashion
Args:
offset (int): start offset
size (int): requried size of position encoding
Returns:
paddle.Tensor: Corresponding position encoding
"""
raise NotImplementedError("position_encoding method is not implemented")
class NoPositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self, def __init__(self,
d_model: int, d_model: int,
dropout_rate: float, dropout_rate: float,
max_len: int=5000, max_len: int=5000,
reverse: bool=False): reverse: bool=False):
super().__init__() nn.Layer.__init__(self)
def forward(self, x: paddle.Tensor, def forward(self, x: paddle.Tensor,
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]: offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
@ -43,7 +65,7 @@ class NoPositionalEncoding(nn.Layer):
return None return None
class PositionalEncoding(nn.Layer): class PositionalEncoding(nn.Layer, PositionalEncodingInterface):
def __init__(self, def __init__(self,
d_model: int, d_model: int,
dropout_rate: float, dropout_rate: float,
@ -58,7 +80,7 @@ class PositionalEncoding(nn.Layer):
max_len (int, optional): maximum input length. Defaults to 5000. max_len (int, optional): maximum input length. Defaults to 5000.
reverse (bool, optional): Not used. Defaults to False. reverse (bool, optional): Not used. Defaults to False.
""" """
super().__init__() nn.Layer.__init__(self)
self.d_model = d_model self.d_model = d_model
self.max_len = max_len self.max_len = max_len
self.xscale = paddle.to_tensor(math.sqrt(self.d_model)) self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
@ -103,7 +125,7 @@ class PositionalEncoding(nn.Layer):
offset (int): start offset offset (int): start offset
size (int): requried size of position encoding size (int): requried size of position encoding
Returns: Returns:
paddle.Tensor: Corresponding encoding paddle.Tensor: Corresponding position encoding
""" """
assert offset + size < self.max_len assert offset + size < self.max_len
return self.dropout(self.pe[:, offset:offset + size]) return self.dropout(self.pe[:, offset:offset + size])

@ -0,0 +1,149 @@
import io
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
def __init__(
self,
stats,
norm_means=True,
norm_vars=False,
filetype="mat",
utt2spk=None,
spk2utt=None,
reverse=False,
std_floor=1.0e-20,
):
self.stats_file = stats
self.norm_means = norm_means
self.norm_vars = norm_vars
self.reverse = reverse
if isinstance(stats, dict):
stats_dict = dict(stats)
else:
# Use for global CMVN
if filetype == "mat":
stats_dict = {None: kaldiio.load_mat(stats)}
# Use for global CMVN
elif filetype == "npy":
stats_dict = {None: np.load(stats)}
# Use for speaker CMVN
elif filetype == "ark":
self.accept_uttid = True
stats_dict = dict(kaldiio.load_ark(stats))
# Use for speaker CMVN
elif filetype == "hdf5":
self.accept_uttid = True
stats_dict = h5py.File(stats)
else:
raise ValueError("Not supporting filetype={}".format(filetype))
if utt2spk is not None:
self.utt2spk = {}
with io.open(utt2spk, "r", encoding="utf-8") as f:
for line in f:
utt, spk = line.rstrip().split(None, 1)
self.utt2spk[utt] = spk
elif spk2utt is not None:
self.utt2spk = {}
with io.open(spk2utt, "r", encoding="utf-8") as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
self.utt2spk[utt] = spk
else:
self.utt2spk = None
# Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1),
# and the first vector contains the sum of feats and the second is
# the sum of squares. The last value of the first, i.e. stats[0,-1],
# is the number of samples for this statistics.
self.bias = {}
self.scale = {}
for spk, stats in stats_dict.items():
assert len(stats) == 2, stats.shape
count = stats[0, -1]
# If the feature has two or more dimensions
if not (np.isscalar(count) or isinstance(count, (int, float))):
# The first is only used
count = count.flatten()[0]
mean = stats[0, :-1] / count
# V(x) = E(x^2) - (E(x))^2
var = stats[1, :-1] / count - mean * mean
std = np.maximum(np.sqrt(var), std_floor)
self.bias[spk] = -mean
self.scale[spk] = 1 / std
def __repr__(self):
return (
"{name}(stats_file={stats_file}, "
"norm_means={norm_means}, norm_vars={norm_vars}, "
"reverse={reverse})".format(
name=self.__class__.__name__,
stats_file=self.stats_file,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
reverse=self.reverse,
)
)
def __call__(self, x, uttid=None):
if self.utt2spk is not None:
spk = self.utt2spk[uttid]
else:
spk = uttid
if not self.reverse:
# apply cmvn
if self.norm_means:
x = np.add(x, self.bias[spk])
if self.norm_vars:
x = np.multiply(x, self.scale[spk])
else:
# apply reverse cmvn
if self.norm_vars:
x = np.divide(x, self.scale[spk])
if self.norm_means:
x = np.subtract(x, self.bias[spk])
return x
class UtteranceCMVN():
"Apply Utterance CMVN"
def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20):
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
def __repr__(self):
return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format(
name=self.__class__.__name__,
norm_means=self.norm_means,
norm_vars=self.norm_vars,
)
def __call__(self, x, uttid=None):
# x: [Time, Dim]
square_sums = (x ** 2).sum(axis=0)
mean = x.mean(axis=0)
if self.norm_means:
x = np.subtract(x, mean)
if self.norm_vars:
var = square_sums / x.shape[0] - mean ** 2
std = np.maximum(np.sqrt(var), self.std_floor)
x = np.divide(x, std)
return x

@ -0,0 +1,237 @@
import io
import logging
import sys
import h5py
import kaldiio
import soundfile
from deepspeech.io.reader import SoundHDF5File
def file_reader_helper(
rspecifier: str,
filetype: str = "mat",
return_shape: bool = False,
segments: str = None,
):
"""Read uttid and array in kaldi style
This function might be a bit confusing as "ark" is used
for HDF5 to imitate "kaldi-rspecifier".
Args:
rspecifier: Give as "ark:feats.ark" or "scp:feats.scp"
filetype: "mat" is kaldi-martix, "hdf5": HDF5
return_shape: Return the shape of the matrix,
instead of the matrix. This can reduce IO cost for HDF5.
segments (str): The file format is
"<segment-id> <recording-id> <start-time> <end-time>\n"
"e.g. call-861225-A-0050-0065 call-861225-A 5.0 6.5\n"
Returns:
Generator[Tuple[str, np.ndarray], None, None]:
Examples:
Read from kaldi-matrix ark file:
>>> for u, array in file_reader_helper('ark:feats.ark', 'mat'):
... array
Read from HDF5 file:
>>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'):
... array
"""
if filetype == "mat":
return KaldiReader(rspecifier, return_shape=return_shape, segments=segments)
elif filetype == "hdf5":
return HDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound.hdf5":
return SoundHDF5Reader(rspecifier, return_shape=return_shape)
elif filetype == "sound":
return SoundReader(rspecifier, return_shape=return_shape)
else:
raise NotImplementedError(f"filetype={filetype}")
class KaldiReader:
def __init__(self, rspecifier, return_shape=False, segments=None):
self.rspecifier = rspecifier
self.return_shape = return_shape
self.segments = segments
def __iter__(self):
with kaldiio.ReadHelper(self.rspecifier, segments=self.segments) as reader:
for key, array in reader:
if self.return_shape:
array = array.shape
yield key, array
class HDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError(
'Give "rspecifier" such as "ark:some.ark: {}"'.format(self.rspecifier)
)
self.rspecifier = rspecifier
self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath
)
)
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = h5py.File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error(
"Error when loading {} with key={}".format(path, h5_key)
)
raise
if self.return_shape:
yield key, data.shape
else:
yield key, data[()]
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
with h5py.File(filepath, "r") as f:
for key in f:
if self.return_shape:
yield key, f[key].shape
else:
yield key, f[key][()]
class SoundHDF5Reader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError(
'Give "rspecifier" such as "ark:some.ark: {}"'.format(rspecifier)
)
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp not in ["ark", "scp"]:
raise ValueError(f"Must be scp or ark: {self.ark_or_scp}")
self.return_shape = return_shape
def __iter__(self):
if self.ark_or_scp == "scp":
hdf5_dict = {}
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, value = line.rstrip().split(None, 1)
if ":" not in value:
raise RuntimeError(
"scp file for hdf5 should be like: "
'"uttid filepath.h5:key": {}({})'.format(
line, self.filepath
)
)
path, h5_key = value.split(":", 1)
hdf5_file = hdf5_dict.get(path)
if hdf5_file is None:
try:
hdf5_file = SoundHDF5File(path, "r")
except Exception:
logging.error("Error when loading {}".format(path))
raise
hdf5_dict[path] = hdf5_file
try:
data = hdf5_file[h5_key]
except Exception:
logging.error(
"Error when loading {} with key={}".format(path, h5_key)
)
raise
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
array, rate = data
if self.return_shape:
array = array.shape
yield key, (rate, array)
# Closing all files
for k in hdf5_dict:
try:
hdf5_dict[k].close()
except Exception:
pass
else:
if self.filepath == "-":
# Required h5py>=2.9
filepath = io.BytesIO(sys.stdin.buffer.read())
else:
filepath = self.filepath
for key, (a, r) in SoundHDF5File(filepath, "r").items():
if self.return_shape:
a = a.shape
yield key, (r, a)
class SoundReader:
def __init__(self, rspecifier, return_shape=False):
if ":" not in rspecifier:
raise ValueError(
'Give "rspecifier" such as "scp:some.scp: {}"'.format(rspecifier)
)
self.ark_or_scp, self.filepath = rspecifier.split(":", 1)
if self.ark_or_scp != "scp":
raise ValueError(
'Only supporting "scp" for sound file: {}'.format(self.ark_or_scp)
)
self.return_shape = return_shape
def __iter__(self):
with open(self.filepath, "r", encoding="utf-8") as f:
for line in f:
key, sound_file_path = line.rstrip().split(None, 1)
# Assume PCM16
array, rate = soundfile.read(sound_file_path, dtype="int16")
# Change Tuple[ndarray, int] -> Tuple[int, ndarray]
# (soundfile style -> scipy style)
if self.return_shape:
array = array.shape
yield key, (rate, array)

@ -0,0 +1,65 @@
from collections.abc import Sequence
from distutils.util import strtobool as dist_strtobool
import sys
import numpy
def strtobool(x):
# distutils.util.strtobool returns integer, but it's confusing,
return bool(dist_strtobool(x))
def get_commandline_args():
extra_chars = [
" ",
";",
"&",
"(",
")",
"|",
"^",
"<",
">",
"?",
"*",
"[",
"]",
"$",
"`",
'"',
"\\",
"!",
"{",
"}",
]
# Escape the extra characters for shell
argv = [
arg.replace("'", "'\\''")
if all(char not in arg for char in extra_chars)
else "'" + arg.replace("'", "'\\''") + "'"
for arg in sys.argv
]
return sys.executable + " " + " ".join(argv)
def is_scipy_wav_style(value):
# If Tuple[int, numpy.ndarray] or not
return (
isinstance(value, Sequence)
and len(value) == 2
and isinstance(value[0], int)
and isinstance(value[1], numpy.ndarray)
)
def assert_scipy_wav_style(value):
assert is_scipy_wav_style(
value
), "Must be Tuple[int, numpy.ndarray], but got {}".format(
type(value)
if not isinstance(value, Sequence)
else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
)

@ -0,0 +1,282 @@
from pathlib import Path
from typing import Dict
import h5py
import kaldiio
import numpy
import soundfile
from deepspeech.utils.cli_utils import assert_scipy_wav_style
from deepspeech.io.reader import SoundHDF5File
def file_writer_helper(
wspecifier: str,
filetype: str = "mat",
write_num_frames: str = None,
compress: bool = False,
compression_method: int = 2,
pcm_format: str = "wav",
):
"""Write matrices in kaldi style
Args:
wspecifier: e.g. ark,scp:out.ark,out.scp
filetype: "mat" is kaldi-martix, "hdf5": HDF5
write_num_frames: e.g. 'ark,t:num_frames.txt'
compress: Compress or not
compression_method: Specify compression level
Write in kaldi-matrix-ark with "kaldi-scp" file:
>>> with file_writer_helper('ark,scp:out.ark,out.scp') as f:
>>> f['uttid'] = array
This "scp" has the following format:
uttidA out.ark:1234
uttidB out.ark:2222
where, 1234 and 2222 points the strating byte address of the matrix.
(For detail, see official documentation of Kaldi)
Write in HDF5 with "scp" file:
>>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f:
>>> f['uttid'] = array
This "scp" file is created as:
uttidA out.h5:uttidA
uttidB out.h5:uttidB
HDF5 can be, unlike "kaldi-ark", accessed to any keys,
so originally "scp" is not required for random-reading.
Nevertheless we create "scp" for HDF5 because it is useful
for some use-case. e.g. Concatenation, Splitting.
"""
if filetype == "mat":
return KaldiWriter(
wspecifier,
write_num_frames=write_num_frames,
compress=compress,
compression_method=compression_method,
)
elif filetype == "hdf5":
return HDF5Writer(
wspecifier, write_num_frames=write_num_frames, compress=compress
)
elif filetype == "sound.hdf5":
return SoundHDF5Writer(
wspecifier, write_num_frames=write_num_frames, pcm_format=pcm_format
)
elif filetype == "sound":
return SoundWriter(
wspecifier, write_num_frames=write_num_frames, pcm_format=pcm_format
)
else:
raise NotImplementedError(f"filetype={filetype}")
class BaseWriter:
def __setitem__(self, key, value):
raise NotImplementedError
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def close(self):
try:
self.writer.close()
except Exception:
pass
if self.writer_scp is not None:
try:
self.writer_scp.close()
except Exception:
pass
if self.writer_nframe is not None:
try:
self.writer_nframe.close()
except Exception:
pass
def get_num_frames_writer(write_num_frames: str):
"""get_num_frames_writer
Examples:
>>> get_num_frames_writer('ark,t:num_frames.txt')
"""
if write_num_frames is not None:
if ":" not in write_num_frames:
raise ValueError(
'Must include ":", write_num_frames={}'.format(write_num_frames)
)
nframes_type, nframes_file = write_num_frames.split(":", 1)
if nframes_type != "ark,t":
raise ValueError(
"Only supporting text mode. "
"e.g. --write-num-frames=ark,t:foo.txt :"
"{}".format(nframes_type)
)
return open(nframes_file, "w", encoding="utf-8")
class KaldiWriter(BaseWriter):
def __init__(
self, wspecifier, write_num_frames=None, compress=False, compression_method=2
):
if compress:
self.writer = kaldiio.WriteHelper(
wspecifier, compression_method=compression_method
)
else:
self.writer = kaldiio.WriteHelper(wspecifier)
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer[key] = value
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
def parse_wspecifier(wspecifier: str) -> Dict[str, str]:
"""Parse wspecifier to dict
Examples:
>>> parse_wspecifier('ark,scp:out.ark,out.scp')
{'ark': 'out.ark', 'scp': 'out.scp'}
"""
ark_scp, filepath = wspecifier.split(":", 1)
if ark_scp not in ["ark", "scp,ark", "ark,scp"]:
raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier))
ark_scps = ark_scp.split(",")
filepaths = filepath.split(",")
if len(ark_scps) != len(filepaths):
raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath))
spec_dict = dict(zip(ark_scps, filepaths))
return spec_dict
class HDF5Writer(BaseWriter):
"""HDF5Writer
Examples:
>>> with HDF5Writer('ark:out.h5', compress=True) as f:
... f['key'] = array
"""
def __init__(self, wspecifier, write_num_frames=None, compress=False):
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
if compress:
self.kwargs = {"compression": "gzip"}
else:
self.kwargs = {}
self.writer = h5py.File(spec_dict["ark"], "w")
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
self.writer.create_dataset(key, data=value, **self.kwargs)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value)}\n")
class SoundHDF5Writer(BaseWriter):
"""SoundHDF5Writer
Examples:
>>> fs = 16000
>>> with SoundHDF5Writer('ark:out.h5') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
self.filename = spec_dict["ark"]
self.writer = SoundHDF5File(spec_dict["ark"], "w", format=self.pcm_format)
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
# Change Tuple[int, ndarray] -> Tuple[ndarray, int]
# (scipy style -> soundfile style)
value = (value[1], value[0])
self.writer.create_dataset(key, data=value)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {self.filename}:{key}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(value[0])}\n")
class SoundWriter(BaseWriter):
"""SoundWriter
Examples:
>>> fs = 16000
>>> with SoundWriter('ark,scp:outdir,out.scp') as f:
... f['key'] = fs, array
"""
def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"):
self.pcm_format = pcm_format
spec_dict = parse_wspecifier(wspecifier)
# e.g. ark,scp:dirname,wav.scp
# -> The wave files are found in dirname/*.wav
self.dirname = spec_dict["ark"]
Path(self.dirname).mkdir(parents=True, exist_ok=True)
self.writer = None
if "scp" in spec_dict:
self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8")
else:
self.writer_scp = None
if write_num_frames is not None:
self.writer_nframe = get_num_frames_writer(write_num_frames)
else:
self.writer_nframe = None
def __setitem__(self, key, value):
assert_scipy_wav_style(value)
rate, signal = value
wavfile = Path(self.dirname) / (key + "." + self.pcm_format)
soundfile.write(wavfile, signal.astype(numpy.int16), rate)
if self.writer_scp is not None:
self.writer_scp.write(f"{key} {wavfile}\n")
if self.writer_nframe is not None:
self.writer_nframe.write(f"{key} {len(signal)}\n")

@ -2,6 +2,20 @@
stage=-1 stage=-1
stop_stage=100 stop_stage=100
nj=32
debugmode=1
dumpdir=dump # directory to dump full features
N=0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
verbose=0 # verbose option
resume= # Resume the training from snapshot
# feature configuration
do_delta=false
# Set this to somewhere where you want to put your data, or where
# someone else has already put it. You'll want to change this
# if you're not on the CLSP grid.
datadir=${MAIN_ROOT}/examples/dataset/
# bpemode (unigram or bpe) # bpemode (unigram or bpe)
nbpe=5000 nbpe=5000
@ -10,11 +24,21 @@ bpeprefix="data/bpe_${bpemode}_${nbpe}"
source ${MAIN_ROOT}/utils/parse_options.sh source ${MAIN_ROOT}/utils/parse_options.sh
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail
train_set=train_960
train_sp=train_sp
train_dev=dev
recog_set="test_clean test_other dev_clean dev_other"
mkdir -p data mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR} mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests # download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \ python3 ${TARGET_DIR}/librispeech/librispeech.py \
@ -46,43 +70,89 @@ if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
fi fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer ### Task dependent. You have to make data the following preparation part by yourself.
num_workers=$(nproc) ### But you can utilize Kaldi recipes in most cases
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \ echo "stage 0: Data preparation"
--manifest_path="data/manifest.train.raw" \ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
--num_samples=-1 \ # use underscore-separated names in data directories.
--spectrum_type="fbank" \ local/data_prep.sh ${datadir}/librispeech/${part}/LibriSpeech/${part} data/${part//-/_}
--feat_dim=80 \ done
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi fi
feat_tr_dir=${dumpdir}/${train_set}/delta${do_delta}; mkdir -p ${feat_tr_dir}
feat_sp_dir=${dumpdir}/${train_sp}/delta${do_delta}; mkdir -p ${feat_sp_dir}
feat_dt_dir=${dumpdir}/${train_dev}/delta${do_delta}; mkdir -p ${feat_dt_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# build vocabulary ### Task dependent. You have to design training and dev sets by yourself.
python3 ${MAIN_ROOT}/utils/build_vocab.py \ ### But you can utilize Kaldi recipes in most cases
--unit_type "spm" \ echo "stage 1: Feature Generation"
--spm_vocab_size=${nbpe} \ fbankdir=fbank
--spm_mode ${bpemode} \ # Generate the fbank features; by default 80-dimensional fbanks with pitch on each frame
--spm_model_prefix ${bpeprefix} \ for x in dev_clean test_clean dev_other test_other train_clean_100 train_clean_360 train_other_500; do
--vocab_path="data/vocab.txt" \ steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj ${nj} --write_utt2num_frames true \
--manifest_paths="data/manifest.train.raw" data/${x} exp/make_fbank/${x} ${fbankdir}
utils/fix_data_dir.sh data/${x}
done
if [ $? -ne 0 ]; then utils/combine_data.sh --extra_files utt2num_frames data/${train_set}_org data/train_clean_100 data/train_clean_360 data/train_other_500
echo "Build vocabulary failed. Terminated." utils/combine_data.sh --extra_files utt2num_frames data/${train_dev}_org data/dev_clean data/dev_other
exit 1 utils/perturb_data_dir_speed.sh 0.9 data/${train_set}_org data/temp1
fi utils/perturb_data_dir_speed.sh 1.0 data/${train_set}_org data/temp2
utils/perturb_data_dir_speed.sh 1.1 data/${train_set}_org data/temp3
utils/combine_data.sh --extra-files utt2uniq data/${train_sp}_org data/temp1 data/temp2 data/temp3
# remove utt having more than 3000 frames
# remove utt having more than 400 characters
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_set}_org data/${train_set}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_sp}_org data/${train_sp}
remove_longshortdata.sh --maxframes 3000 --maxchars 400 data/${train_dev}_org data/${train_dev}
steps/make_fbank_pitch.sh --cmd "$train_cmd" --nj $nj --write_utt2num_frames true \
data/train_sp exp/make_fbank/train_sp ${fbankdir}
utils/fix_data_dir.sh data/train_sp
# compute global CMVN
compute-cmvn-stats scp:data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark
# dump features for training
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_sp}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/train ${feat_sp_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${train_dev}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/dev ${feat_dt_dir}
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}; mkdir -p ${feat_recog_dir}
dump.sh --cmd "$train_cmd" --nj ${nj} --do_delta ${do_delta} \
data/${rtask}/feats.scp data/${train_sp}/cmvn.ark exp/dump_feats/recog/${rtask} \
${feat_recog_dir}
done
fi fi
dict=data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=data/lang_char/${train_set}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
### Task dependent. You have to check non-linguistic symbols used in the corpus.
echo "stage 2: Dictionary and Json Data Preparation"
mkdir -p data/lang_char/
echo "<unk> 1" > ${dict} # <unk> must be 1, 0 will be used for "blank" in CTC
cut -f 2- -d" " data/${train_set}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict}
wc -l ${dict}
# make json labels
data2json.sh --nj ${nj} --feat ${feat_sp_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${train_sp} ${dict} > ${feat_sp_dir}/data_${bpemode}${nbpe}.json
data2json.sh --nj ${nj} --feat ${feat_dt_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${train_dev} ${dict} > ${feat_dt_dir}/data_${bpemode}${nbpe}.json
for rtask in ${recog_set}; do
feat_recog_dir=${dumpdir}/${rtask}/delta${do_delta}
data2json.sh --nj ${nj} --feat ${feat_recog_dir}/feats.scp --bpecode ${bpemodel}.model \
data/${rtask} ${dict} > ${feat_recog_dir}/data_${bpemode}${nbpe}.json
done
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size # format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do for set in train dev test dev-clean dev-other test-clean test-other; do

@ -0,0 +1,85 @@
#!/usr/bin/env bash
# Copyright 2014 Vassil Panayotov
# 2014 Johns Hopkins University (author: Daniel Povey)
# Apache 2.0
if [ "$#" -ne 2 ]; then
echo "Usage: $0 <src-dir> <dst-dir>"
echo "e.g.: $0 /export/a15/vpanayotov/data/LibriSpeech/dev-clean data/dev-clean"
exit 1
fi
src=$1
dst=$2
# all utterances are FLAC compressed
if ! which flac >&/dev/null; then
echo "Please install 'flac' on ALL worker nodes!"
exit 1
fi
spk_file=$src/../SPEAKERS.TXT
mkdir -p $dst || exit 1
[ ! -d $src ] && echo "$0: no such directory $src" && exit 1
[ ! -f $spk_file ] && echo "$0: expected file $spk_file to exist" && exit 1
wav_scp=$dst/wav.scp; [[ -f "$wav_scp" ]] && rm $wav_scp
trans=$dst/text; [[ -f "$trans" ]] && rm $trans
utt2spk=$dst/utt2spk; [[ -f "$utt2spk" ]] && rm $utt2spk
spk2gender=$dst/spk2gender; [[ -f $spk2gender ]] && rm $spk2gender
for reader_dir in $(find -L $src -mindepth 1 -maxdepth 1 -type d | sort); do
reader=$(basename $reader_dir)
if ! [ $reader -eq $reader ]; then # not integer.
echo "$0: unexpected subdirectory name $reader"
exit 1
fi
reader_gender=$(egrep "^$reader[ ]+\|" $spk_file | awk -F'|' '{gsub(/[ ]+/, ""); print tolower($2)}')
if [ "$reader_gender" != 'm' ] && [ "$reader_gender" != 'f' ]; then
echo "Unexpected gender: '$reader_gender'"
exit 1
fi
for chapter_dir in $(find -L $reader_dir/ -mindepth 1 -maxdepth 1 -type d | sort); do
chapter=$(basename $chapter_dir)
if ! [ "$chapter" -eq "$chapter" ]; then
echo "$0: unexpected chapter-subdirectory name $chapter"
exit 1
fi
find -L $chapter_dir/ -iname "*.flac" | sort | xargs -I% basename % .flac | \
awk -v "dir=$chapter_dir" '{printf "%s flac -c -d -s %s/%s.flac |\n", $0, dir, $0}' >>$wav_scp|| exit 1
chapter_trans=$chapter_dir/${reader}-${chapter}.trans.txt
[ ! -f $chapter_trans ] && echo "$0: expected file $chapter_trans to exist" && exit 1
cat $chapter_trans >>$trans
# NOTE: For now we are using per-chapter utt2spk. That is each chapter is considered
# to be a different speaker. This is done for simplicity and because we want
# e.g. the CMVN to be calculated per-chapter
awk -v "reader=$reader" -v "chapter=$chapter" '{printf "%s %s-%s\n", $1, reader, chapter}' \
<$chapter_trans >>$utt2spk || exit 1
# reader -> gender map (again using per-chapter granularity)
echo "${reader}-${chapter} $reader_gender" >>$spk2gender
done
done
spk2utt=$dst/spk2utt
utils/utt2spk_to_spk2utt.pl <$utt2spk >$spk2utt || exit 1
ntrans=$(wc -l <$trans)
nutt2spk=$(wc -l <$utt2spk)
! [ "$ntrans" -eq "$nutt2spk" ] && \
echo "Inconsistent #transcripts($ntrans) and #utt2spk($nutt2spk)" && exit 1
utils/validate_data_dir.sh --no-feats $dst || exit 1
echo "$0: successfully prepared data in $dst"
exit 0

@ -0,0 +1 @@
../../../tools/kaldi/egs/wsj/s5/steps/

@ -1 +1 @@
../../../utils/ ../../../tools/kaldi/egs/wsj/s5/utils

@ -48,6 +48,9 @@ mfa.done:
tar xvf montreal-forced-aligner_linux.tar.gz tar xvf montreal-forced-aligner_linux.tar.gz
touch mfa.done touch mfa.done
kaldi.done:
test -d kaldi || git clone --depth 1 https://github.com/kaldi-asr/kaldi
touch kaldi.done
#== SCTK =============================================================================== #== SCTK ===============================================================================
# SCTK official repo does not have version tags. Here's the mapping: # SCTK official repo does not have version tags. Here's the mapping:

@ -0,0 +1,156 @@
#!/usr/bin/env python3
import argparse
from distutils.util import strtobool
import logging
import kaldiio
import numpy
from deepspeech.transform.cmvn import CMVN
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="apply mean-variance normalization to files",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--stats-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--norm-means",
type=strtobool,
default=True,
help="Do variance normalization or not.",
)
parser.add_argument(
"--norm-vars",
type=strtobool,
default=False,
help="Do variance normalization or not.",
)
parser.add_argument(
"--reverse", type=strtobool, default=False, help="Do reverse mode or not"
)
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:spk2utt")',
)
parser.add_argument(
"--utt2spk",
type=str,
help="A text file of utterance to speaker map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")',
)
parser.add_argument(
"--write-num-frames", type=str, help="Specify wspecifer for utt2num_frames"
)
parser.add_argument(
"--compress", type=strtobool, default=False, help="Save in compressed format"
)
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or " "gzip-level(if hdf5)",
)
parser.add_argument(
"stats_rspecifier_or_rxfilename",
help="Input stats. e.g. ark:stats.ark or stats.mat",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier id. e.g. ark:some.ark"
)
parser.add_argument(
"wspecifier", type=str, help="Write specifier id. e.g. ark:some.ark"
)
return parser
def main():
args = get_parser().parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if ":" in args.stats_rspecifier_or_rxfilename:
is_rspcifier = True
if args.stats_filetype == "npy":
stats_filetype = "hdf5"
else:
stats_filetype = args.stats_filetype
stats_dict = dict(
file_reader_helper(args.stats_rspecifier_or_rxfilename, stats_filetype)
)
else:
is_rspcifier = False
if args.stats_filetype == "mat":
stats = kaldiio.load_mat(args.stats_rspecifier_or_rxfilename)
else:
stats = numpy.load(args.stats_rspecifier_or_rxfilename)
stats_dict = {None: stats}
cmvn = CMVN(
stats=stats_dict,
norm_means=args.norm_means,
norm_vars=args.norm_vars,
utt2spk=args.utt2spk,
spk2utt=args.spk2utt,
reverse=args.reverse,
)
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method,
) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
mat = cmvn(mat, utt if is_rspcifier else None)
writer[utt] = mat
if __name__ == "__main__":
main()

@ -0,0 +1,68 @@
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2021 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
from dateutil import parser
import glob
import os
def get_parser():
parser = argparse.ArgumentParser(description="calculate real time factor (RTF)")
parser.add_argument(
"--log-dir",
type=str,
default=None,
help="path to logging directory",
)
return parser
def main():
args = get_parser().parse_args()
audio_sec = 0
decode_sec = 0
n_utt = 0
audio_durations = []
start_times = []
end_times = []
for x in glob.glob(os.path.join(args.log_dir, "decode.*.log")):
with codecs.open(x, "r", "utf-8") as f:
for line in f:
x = line.strip()
if "INFO: input lengths" in x:
audio_durations += [int(x.split("input lengths: ")[1])]
start_times += [parser.parse(x.split("(")[0])]
elif "INFO: prediction" in x:
end_times += [parser.parse(x.split("(")[0])]
assert len(audio_durations) == len(end_times), (
len(audio_durations),
len(end_times),
)
assert len(start_times) == len(end_times), (len(start_times), len(end_times))
audio_sec += sum(audio_durations) / 100 # [sec]
decode_sec += sum(
[
(end - start).total_seconds()
for start, end in zip(start_times, end_times)
]
)
n_utt += len(audio_durations)
print("Total audio duration: %.3f [sec]" % audio_sec)
print("Total decoding time: %.3f [sec]" % decode_sec)
rtf = decode_sec / audio_sec if audio_sec > 0 else 0
print("RTF: %.3f" % rtf)
latency = decode_sec * 1000 / n_utt if n_utt > 0 else 0
print("Latency: %.3f [ms/sentence]" % latency)
if __name__ == "__main__":
main()

@ -0,0 +1,194 @@
#!/usr/bin/env python3
import argparse
import logging
import kaldiio
import numpy as np
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="Compute cepstral mean and "
"variance normalization statistics"
"If wspecifier provided: per-utterance by default, "
"or per-speaker if"
"spk2utt option provided; if wxfilename: global",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--spk2utt",
type=str,
help="A text file of speaker to utterance-list map. "
"(Don't give rspecifier format, such as "
'"ark:utt2spk")',
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "npy"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
parser.add_argument(
"wspecifier_or_wxfilename", type=str, help="Write specifier. e.g. ark:some.ark"
)
return parser
def main():
args = get_parser().parse_args()
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
is_wspecifier = ":" in args.wspecifier_or_wxfilename
if is_wspecifier:
if args.spk2utt is not None:
logging.info("Performing as speaker CMVN mode")
utt2spk_dict = {}
with open(args.spk2utt) as f:
for line in f:
spk, utts = line.rstrip().split(None, 1)
for utt in utts.split():
utt2spk_dict[utt] = spk
def utt2spk(x):
return utt2spk_dict[x]
else:
logging.info("Performing as utterance CMVN mode")
def utt2spk(x):
return x
if args.out_filetype == "npy":
logging.warning(
"--out-filetype npy is allowed only for "
"Global CMVN mode, changing to hdf5"
)
args.out_filetype = "hdf5"
else:
logging.info("Performing as global CMVN mode")
if args.spk2utt is not None:
logging.warning("spk2utt is not used for global CMVN mode")
def utt2spk(x):
return None
if args.out_filetype == "hdf5":
logging.warning(
"--out-filetype hdf5 is not allowed for "
"Global CMVN mode, changing to npy"
)
args.out_filetype = "npy"
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
# Calculate stats for each speaker
counts = {}
sum_feats = {}
square_sum_feats = {}
idx = 0
for idx, (utt, matrix) in enumerate(
file_reader_helper(args.rspecifier, args.in_filetype), 1
):
if is_scipy_wav_style(matrix):
# If data is sound file, then got as Tuple[int, ndarray]
rate, matrix = matrix
if preprocessing is not None:
matrix = preprocessing(matrix, uttid_list=utt)
spk = utt2spk(utt)
# Init at the first seen of the spk
if spk not in counts:
counts[spk] = 0
feat_shape = matrix.shape[1:]
# Accumulate in double precision
sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
square_sum_feats[spk] = np.zeros(feat_shape, dtype=np.float64)
counts[spk] += matrix.shape[0]
sum_feats[spk] += matrix.sum(axis=0)
square_sum_feats[spk] += (matrix ** 2).sum(axis=0)
logging.info("Processed {} utterances".format(idx))
assert idx > 0, idx
cmvn_stats = {}
for spk in counts:
feat_shape = sum_feats[spk].shape
cmvn_shape = (2, feat_shape[0] + 1) + feat_shape[1:]
_cmvn_stats = np.empty(cmvn_shape, dtype=np.float64)
_cmvn_stats[0, :-1] = sum_feats[spk]
_cmvn_stats[1, :-1] = square_sum_feats[spk]
_cmvn_stats[0, -1] = counts[spk]
_cmvn_stats[1, -1] = 0.0
# You can get the mean and std as following,
# >>> N = _cmvn_stats[0, -1]
# >>> mean = _cmvn_stats[0, :-1] / N
# >>> std = np.sqrt(_cmvn_stats[1, :-1] / N - mean ** 2)
cmvn_stats[spk] = _cmvn_stats
# Per utterance or speaker CMVN
if is_wspecifier:
with file_writer_helper(
args.wspecifier_or_wxfilename, filetype=args.out_filetype
) as writer:
for spk, mat in cmvn_stats.items():
writer[spk] = mat
# Global CMVN
else:
matrix = cmvn_stats[None]
if args.out_filetype == "npy":
np.save(args.wspecifier_or_wxfilename, matrix)
elif args.out_filetype == "mat":
# Kaldi supports only matrix or vector
kaldiio.save_mat(args.wspecifier_or_wxfilename, matrix)
else:
raise RuntimeError(
"Not supporting: --out-filetype {}".format(args.out_filetype)
)
if __name__ == "__main__":
main()

@ -0,0 +1,105 @@
#!/usr/bin/env python3
import argparse
from distutils.util import strtobool
import logging
from deepspeech.transform.transformation import Transformation
from deepspeech.utils.cli_readers import file_reader_helper
from deepspeech.utils.cli_utils import get_commandline_args
from deepspeech.utils.cli_utils import is_scipy_wav_style
from deepspeech.utils.cli_writers import file_writer_helper
def get_parser():
parser = argparse.ArgumentParser(
description="copy feature with preprocessing",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option")
parser.add_argument(
"--in-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the rspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--out-filetype",
type=str,
default="mat",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for the wspecifier. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument(
"--write-num-frames", type=str, help="Specify wspecifer for utt2num_frames"
)
parser.add_argument(
"--compress", type=strtobool, default=False, help="Save in compressed format"
)
parser.add_argument(
"--compression-method",
type=int,
default=2,
help="Specify the method(if mat) or " "gzip-level(if hdf5)",
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
parser.add_argument(
"rspecifier", type=str, help="Read specifier for feats. e.g. ark:some.ark"
)
parser.add_argument(
"wspecifier", type=str, help="Write specifier. e.g. ark:some.ark"
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
if args.preprocess_conf is not None:
preprocessing = Transformation(args.preprocess_conf)
logging.info("Apply preprocessing: {}".format(preprocessing))
else:
preprocessing = None
with file_writer_helper(
args.wspecifier,
filetype=args.out_filetype,
write_num_frames=args.write_num_frames,
compress=args.compress,
compression_method=args.compression_method,
) as writer:
for utt, mat in file_reader_helper(args.rspecifier, args.in_filetype):
if is_scipy_wav_style(mat):
# If data is sound file, then got as Tuple[int, ndarray]
rate, mat = mat
if preprocessing is not None:
mat = preprocessing(mat, uttid_list=utt)
# shape = (Time, Channel)
if args.out_filetype in ["sound.hdf5", "sound"]:
# Write Tuple[int, numpy.ndarray] (scipy style)
writer[utt] = (rate, mat)
else:
writer[utt] = mat
if __name__ == "__main__":
main()

@ -0,0 +1,170 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" >&2 # Print the command line for logging
. ./path.sh
nj=1
cmd=run.pl
nlsyms=""
lang=""
feat="" # feat.scp
oov="<unk>"
bpecode=""
allow_one_column=false
verbose=0
trans_type=char
filetype=""
preprocess_conf=""
category=""
out="" # If omitted, write in stdout
text=""
multilingual=false
help_message=$(cat << EOF
Usage: $0 <data-dir> <dict>
e.g. $0 data/train data/lang_1char/train_units.txt
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--feat <feat-scp> # feat.scp or feat1.scp,feat2.scp,...
--oov <oov-word> # Default: <unk>
--out <outputfile> # If omitted, write in stdout
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
. utils/parse_options.sh
if [ $# != 2 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
dir=$1
dic=$2
tmpdir=$(mktemp -d ${dir}/tmp-XXXXX)
trap 'rm -rf ${tmpdir}' EXIT
if [ -z ${text} ]; then
text=${dir}/text
fi
# 1. Create scp files for inputs
# These are not necessary for decoding mode, and make it as an option
input=
if [ -n "${feat}" ]; then
_feat_scps=$(echo "${feat}" | tr ',' ' ' )
read -r -a feat_scps <<< $_feat_scps
num_feats=${#feat_scps[@]}
for (( i=1; i<=num_feats; i++ )); do
feat=${feat_scps[$((i-1))]}
mkdir -p ${tmpdir}/input_${i}
input+="input_${i} "
cat ${feat} > ${tmpdir}/input_${i}/feat.scp
# Dump in the "legacy" style JSON format
if [ -n "${filetype}" ]; then
awk -v filetype=${filetype} '{print $1 " " filetype}' ${feat} \
> ${tmpdir}/input_${i}/filetype.scp
fi
feat_to_shape.sh --cmd "${cmd}" --nj ${nj} \
--filetype "${filetype}" \
--preprocess-conf "${preprocess_conf}" \
--verbose ${verbose} ${feat} ${tmpdir}/input_${i}/shape.scp
done
fi
# 2. Create scp files for outputs
mkdir -p ${tmpdir}/output
if [ -n "${bpecode}" ]; then
if [ ${multilingual} = true ]; then
# remove a space before the language ID
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece | cut -f 2- -d" ") \
> ${tmpdir}/output/token.scp
else
paste -d " " <(awk '{print $1}' ${text}) <(cut -f 2- -d" " ${text} \
| spm_encode --model=${bpecode} --output_format=piece) \
> ${tmpdir}/output/token.scp
fi
elif [ -n "${nlsyms}" ]; then
text2token.py -s 1 -n 1 -l ${nlsyms} ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
else
text2token.py -s 1 -n 1 ${text} --trans_type ${trans_type} > ${tmpdir}/output/token.scp
fi
< ${tmpdir}/output/token.scp utils/sym2int.pl --map-oov ${oov} -f 2- ${dic} > ${tmpdir}/output/tokenid.scp
# +2 comes from CTC blank and EOS
vocsize=$(tail -n 1 ${dic} | awk '{print $2}')
odim=$(echo "$vocsize + 2" | bc)
< ${tmpdir}/output/tokenid.scp awk -v odim=${odim} '{print $1 " " NF-1 "," odim}' > ${tmpdir}/output/shape.scp
cat ${text} > ${tmpdir}/output/text.scp
# 3. Create scp files for the others
mkdir -p ${tmpdir}/other
if [ ${multilingual} == true ]; then
awk '{
n = split($1,S,"[-]");
lang=S[n];
print $1 " " lang
}' ${text} > ${tmpdir}/other/lang.scp
elif [ -n "${lang}" ]; then
awk -v lang=${lang} '{print $1 " " lang}' ${text} > ${tmpdir}/other/lang.scp
fi
if [ -n "${category}" ]; then
awk -v category=${category} '{print $1 " " category}' ${dir}/text \
> ${tmpdir}/other/category.scp
fi
cat ${dir}/utt2spk > ${tmpdir}/other/utt2spk.scp
# 4. Merge scp files into a JSON file
opts=""
if [ -n "${feat}" ]; then
intypes="${input} output other"
else
intypes="output other"
fi
for intype in ${intypes}; do
if [ -z "$(find "${tmpdir}/${intype}" -name "*.scp")" ]; then
continue
fi
if [ ${intype} != other ]; then
opts+="--${intype%_*}-scps "
else
opts+="--scps "
fi
for x in "${tmpdir}/${intype}"/*.scp; do
k=$(basename ${x} .scp)
if [ ${k} = shape ]; then
opts+="shape:${x}:shape "
else
opts+="${k}:${x} "
fi
done
done
if ${allow_one_column}; then
opts+="--allow-one-column true "
else
opts+="--allow-one-column false "
fi
if [ -n "${out}" ]; then
opts+="-O ${out}"
fi
merge_scp2json.py --verbose ${verbose} ${opts}
rm -fr ${tmpdir}

@ -0,0 +1,95 @@
#!/usr/bin/env bash
# Copyright 2017 Nagoya University (Tomoki Hayashi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
echo "$0 $*" # Print the command line for logging
. ./path.sh
cmd=run.pl
do_delta=false
nj=1
verbose=0
compress=true
write_utt2num_frames=true
filetype='mat' # mat or hdf5
help_message="Usage: $0 <scp> <cmvnark> <logdir> <dumpdir>"
. utils/parse_options.sh
scp=$1
cvmnark=$2
logdir=$3
dumpdir=$4
if [ $# != 4 ]; then
echo "${help_message}"
exit 1;
fi
set -euo pipefail
mkdir -p ${logdir}
mkdir -p ${dumpdir}
dumpdir=$(perl -e '($dir,$pwd)= @ARGV; if($dir!~m:^/:) { $dir = "$pwd/$dir"; } print $dir; ' ${dumpdir} ${PWD})
for n in $(seq ${nj}); do
# the next command does nothing unless $dumpdir/storage/ exists, see
# utils/create_data_link.pl for more info.
utils/create_data_link.pl ${dumpdir}/feats.${n}.ark
done
if ${write_utt2num_frames}; then
write_num_frames_opt="--write-num-frames=ark,t:$dumpdir/utt2num_frames.JOB"
else
write_num_frames_opt=
fi
# split scp file
split_scps=""
for n in $(seq ${nj}); do
split_scps="$split_scps $logdir/feats.$n.scp"
done
utils/split_scp.pl ${scp} ${split_scps} || exit 1;
# dump features
if ${do_delta}; then
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
add-deltas ark:- ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
else
${cmd} JOB=1:${nj} ${logdir}/dump_feature.JOB.log \
apply-cmvn --norm-vars=true ${cvmnark} scp:${logdir}/feats.JOB.scp ark:- \| \
copy-feats.py --verbose ${verbose} --out-filetype ${filetype} \
--compress=${compress} --compression-method=2 ${write_num_frames_opt} \
ark:- ark,scp:${dumpdir}/feats.JOB.ark,${dumpdir}/feats.JOB.scp \
|| exit 1
fi
# concatenate scp files
for n in $(seq ${nj}); do
cat ${dumpdir}/feats.${n}.scp || exit 1;
done > ${dumpdir}/feats.scp || exit 1
if ${write_utt2num_frames}; then
for n in $(seq ${nj}); do
cat ${dumpdir}/utt2num_frames.${n} || exit 1;
done > ${dumpdir}/utt2num_frames || exit 1
rm ${dumpdir}/utt2num_frames.* 2>/dev/null
fi
# Write the filetype, this will be used for data2json.sh
echo ${filetype} > ${dumpdir}/filetype
# remove temp scps
rm ${logdir}/feats.*.scp 2>/dev/null
if [ ${verbose} -eq 1 ]; then
echo "Succeeded dumping features for training"
fi

@ -0,0 +1,72 @@
#!/usr/bin/env bash
# Begin configuration section.
nj=4
cmd=run.pl
verbose=0
filetype=""
preprocess_conf=""
# End configuration section.
help_message=$(cat << EOF
Usage: $0 [options] <input-scp> <output-scp> [<log-dir>]
e.g.: $0 data/train/feats.scp data/train/shape.scp data/train/log
Options:
--nj <nj> # number of parallel jobs
--cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs.
--filetype <mat|hdf5|sound.hdf5> # Specify the format of feats file
--preprocess-conf <json> # Apply preprocess to feats when creating shape.scp
--verbose <num> # Default: 0
EOF
)
echo "$0 $*" 1>&2 # Print the command line for logging
. parse_options.sh || exit 1;
if [ $# -lt 2 ] || [ $# -gt 3 ]; then
echo "${help_message}" 1>&2
exit 1;
fi
set -euo pipefail
scp=$1
outscp=$2
data=$(dirname ${scp})
if [ $# -eq 3 ]; then
logdir=$3
else
logdir=${data}/log
fi
mkdir -p ${logdir}
nj=$((nj<$(<"${scp}" wc -l)?nj:$(<"${scp}" wc -l)))
split_scps=""
for n in $(seq ${nj}); do
split_scps="${split_scps} ${logdir}/feats.${n}.scp"
done
utils/split_scp.pl ${scp} ${split_scps}
if [ -n "${preprocess_conf}" ]; then
preprocess_opt="--preprocess-conf ${preprocess_conf}"
else
preprocess_opt=""
fi
if [ -n "${filetype}" ]; then
filetype_opt="--filetype ${filetype}"
else
filetype_opt=""
fi
${cmd} JOB=1:${nj} ${logdir}/feat_to_shape.JOB.log \
feat-to-shape.py --verbose ${verbose} ${preprocess_opt} ${filetype_opt} \
scp:${logdir}/feats.JOB.scp ${logdir}/shape.JOB.scp
# concatenate the .scp files together.
for n in $(seq ${nj}); do
cat ${logdir}/shape.${n}.scp
done > ${outscp}
rm -f ${logdir}/feats.*.scp 2>/dev/null

@ -0,0 +1,303 @@
#!/usr/bin/env python3
# encoding: utf-8
import argparse
import codecs
from distutils.util import strtobool
from io import open
import json
import logging
import sys
from deepspeech.utils.cli_utils import get_commandline_args
PY2 = sys.version_info[0] == 2
sys.stdin = codecs.getreader("utf-8")(sys.stdin if PY2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(sys.stdout if PY2 else sys.stdout.buffer)
# Special types:
def shape(x):
"""Change str to List[int]
>>> shape('3,5')
[3, 5]
>>> shape(' [3, 5] ')
[3, 5]
"""
# x: ' [3, 5] ' -> '3, 5'
x = x.strip()
if x[0] == "[":
x = x[1:]
if x[-1] == "]":
x = x[:-1]
return list(map(int, x.split(",")))
def get_parser():
parser = argparse.ArgumentParser(
description="Given each file paths with such format as "
"<key>:<file>:<type>. type> can be omitted and the default "
'is "str". e.g. {} '
"--input-scps feat:data/feats.scp shape:data/utt2feat_shape:shape "
"--input-scps feat:data/feats2.scp shape:data/utt2feat2_shape:shape "
"--output-scps text:data/text shape:data/utt2text_shape:shape "
"--scps utt2spk:data/utt2spk".format(sys.argv[0]),
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--input-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the inputs",
)
parser.add_argument(
"--output-scps",
type=str,
nargs="*",
action="append",
default=[],
help="Json files for the outputs",
)
parser.add_argument(
"--scps",
type=str,
nargs="+",
default=[],
help="The json files except for the input and outputs",
)
parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
parser.add_argument(
"--allow-one-column",
type=strtobool,
default=False,
help="Allow one column in input scp files. "
"In this case, the value will be empty string.",
)
parser.add_argument(
"--out",
"-O",
type=str,
help="The output filename. " "If omitted, then output to sys.stdout",
)
return parser
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
args.scps = [args.scps]
# logging info
logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
if args.verbose > 0:
logging.basicConfig(level=logging.INFO, format=logfmt)
else:
logging.basicConfig(level=logging.WARN, format=logfmt)
logging.info(get_commandline_args())
# List[List[Tuple[str, str, Callable[[str], Any], str, str]]]
input_infos = []
output_infos = []
infos = []
for lis_list, key_scps_list in [
(input_infos, args.input_scps),
(output_infos, args.output_scps),
(infos, args.scps),
]:
for key_scps in key_scps_list:
lis = []
for key_scp in key_scps:
sps = key_scp.split(":")
if len(sps) == 2:
key, scp = sps
type_func = None
type_func_str = "none"
elif len(sps) == 3:
key, scp, type_func_str = sps
fail = False
try:
# type_func: Callable[[str], Any]
# e.g. type_func_str = "int" -> type_func = int
type_func = eval(type_func_str)
except Exception:
raise RuntimeError("Unknown type: {}".format(type_func_str))
if not callable(type_func):
raise RuntimeError("Unknown type: {}".format(type_func_str))
else:
raise RuntimeError(
"Format <key>:<filepath> "
"or <key>:<filepath>:<type> "
"e.g. feat:data/feat.scp "
"or shape:data/feat.scp:shape: {}".format(key_scp)
)
for item in lis:
if key == item[0]:
raise RuntimeError(
'The key "{}" is duplicated: {} {}'.format(
key, item[3], key_scp
)
)
lis.append((key, scp, type_func, key_scp, type_func_str))
lis_list.append(lis)
# Open scp files
input_fscps = [
[open(i[1], "r", encoding="utf-8") for i in il] for il in input_infos
]
output_fscps = [
[open(i[1], "r", encoding="utf-8") for i in il] for il in output_infos
]
fscps = [[open(i[1], "r", encoding="utf-8") for i in il] for il in infos]
# Note(kamo): What is done here?
# The final goal is creating a JSON file such as.
# {
# "utts": {
# "sample_id1": {(omitted)},
# "sample_id2": {(omitted)},
# ....
# }
# }
#
# To reduce memory usage, reading the input text files for each lines
# and writing JSON elements per samples.
if args.out is None:
out = sys.stdout
else:
out = open(args.out, "w", encoding="utf-8")
out.write('{\n "utts": {\n')
nutt = 0
while True:
nutt += 1
# List[List[str]]
input_lines = [[f.readline() for f in fl] for fl in input_fscps]
output_lines = [[f.readline() for f in fl] for fl in output_fscps]
lines = [[f.readline() for f in fl] for fl in fscps]
# Get the first line
concat = sum(input_lines + output_lines + lines, [])
if len(concat) == 0:
break
first = concat[0]
# Sanity check: Must be sorted by the first column and have same keys
count = 0
for ls_list in (input_lines, output_lines, lines):
for ls in ls_list:
for line in ls:
if line == "" or first == "":
if line != first:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError(
"The number of lines mismatch "
'between: "{}" and "{}"'.format(
concat[0][1], concat[count][1]
)
)
elif line.split()[0] != first.split()[0]:
concat = sum(input_infos + output_infos + infos, [])
raise RuntimeError(
"The keys are mismatch at {}th line "
'between "{}" and "{}":\n>>> {}\n>>> {}'.format(
nutt,
concat[0][1],
concat[count][1],
first.rstrip(),
line.rstrip(),
)
)
count += 1
# The end of file
if first == "":
if nutt != 1:
out.write("\n")
break
if nutt != 1:
out.write(",\n")
entry = {}
for inout, _lines, _infos in [
("input", input_lines, input_infos),
("output", output_lines, output_infos),
("other", lines, infos),
]:
lis = []
for idx, (line_list, info_list) in enumerate(zip(_lines, _infos), 1):
if inout == "input":
d = {"name": "input{}".format(idx)}
elif inout == "output":
d = {"name": "target{}".format(idx)}
else:
d = {}
# info_list: List[Tuple[str, str, Callable]]
# line_list: List[str]
for line, info in zip(line_list, info_list):
sps = line.split(None, 1)
if len(sps) < 2:
if not args.allow_one_column:
raise RuntimeError(
"Format error {}th line in {}: "
' Expecting "<key> <value>":\n>>> {}'.format(
nutt, info[1], line
)
)
uttid = sps[0]
value = ""
else:
uttid, value = sps
key = info[0]
type_func = info[2]
value = value.rstrip()
if type_func is not None:
try:
# type_func: Callable[[str], Any]
value = type_func(value)
except Exception:
logging.error(
'"{}" is an invalid function '
"for the {} th line in {}: \n>>> {}".format(
info[4], nutt, info[1], line
)
)
raise
d[key] = value
lis.append(d)
if inout != "other":
entry[inout] = lis
else:
# If key == 'other'. only has the first item
entry.update(lis[0])
entry = json.dumps(
entry, indent=4, ensure_ascii=False, sort_keys=True, separators=(",", ": ")
)
# Add indent
indent = " " * 2
entry = ("\n" + indent).join(entry.split("\n"))
uttid = first.split()[0]
out.write(' "{}": {}'.format(uttid, entry))
out.write(" }\n}\n")
logging.info("{} entries in {}".format(nutt, out.name))

@ -0,0 +1,59 @@
#!/usr/bin/env bash
# koried, 10/29/2012
# Reduce a data set based on a list of turn-ids
help_message="usage: $0 srcdir turnlist destdir"
if [ $1 == "--help" ]; then
echo "${help_message}"
exit 0;
fi
if [ $# != 3 ]; then
echo "${help_message}"
exit 1;
fi
srcdir=$1
reclist=$2
destdir=$3
if [ ! -f ${srcdir}/utt2spk ]; then
echo "$0: no such file $srcdir/utt2spk"
exit 1;
fi
function do_filtering {
# assumes the utt2spk and spk2utt files already exist.
[ -f ${srcdir}/feats.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/feats.scp >${destdir}/feats.scp
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/text ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/text >${destdir}/text
[ -f ${srcdir}/utt2num_frames ] && utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/utt2num_frames >${destdir}/utt2num_frames
[ -f ${srcdir}/spk2gender ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/spk2gender >${destdir}/spk2gender
[ -f ${srcdir}/cmvn.scp ] && utils/filter_scp.pl ${destdir}/spk2utt <${srcdir}/cmvn.scp >${destdir}/cmvn.scp
if [ -f ${srcdir}/segments ]; then
utils/filter_scp.pl ${destdir}/utt2spk <${srcdir}/segments >${destdir}/segments
awk '{print $2;}' ${destdir}/segments | sort | uniq > ${destdir}/reco # recordings.
# The next line would override the command above for wav.scp, which would be incorrect.
[ -f ${srcdir}/wav.scp ] && utils/filter_scp.pl ${destdir}/reco <${srcdir}/wav.scp >${destdir}/wav.scp
[ -f ${srcdir}/reco2file_and_channel ] && \
utils/filter_scp.pl ${destdir}/reco <${srcdir}/reco2file_and_channel >${destdir}/reco2file_and_channel
# Filter the STM file for proper sclite scoring (this will also remove the comments lines)
[ -f ${srcdir}/stm ] && utils/filter_scp.pl ${destdir}/reco < ${srcdir}/stm > ${destdir}/stm
rm ${destdir}/reco
fi
srcutts=$(wc -l < ${srcdir}/utt2spk)
destutts=$(wc -l < ${destdir}/utt2spk)
echo "Reduced #utt from $srcutts to $destutts"
}
mkdir -p ${destdir}
# filter the utt2spk based on the set of recordings
utils/filter_scp.pl ${reclist} < ${srcdir}/utt2spk > ${destdir}/utt2spk
utils/utt2spk_to_spk2utt.pl < ${destdir}/utt2spk > ${destdir}/spk2utt
do_filtering;

@ -0,0 +1,62 @@
#!/usr/bin/env bash
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
. ./path.sh
maxframes=2000
minframes=10
maxchars=200
minchars=0
nlsyms=""
no_feat=false
trans_type=char
help_message="usage: $0 olddatadir newdatadir"
. utils/parse_options.sh || exit 1;
if [ $# != 2 ]; then
echo "${help_message}"
exit 1;
fi
sdir=$1
odir=$2
mkdir -p ${odir}/tmp
if [ ${no_feat} = true ]; then
# for machine translation
cut -d' ' -f 1 ${sdir}/text > ${odir}/tmp/reclist1
else
echo "extract utterances having less than $maxframes or more than $minframes frames"
utils/data/get_utt2num_frames.sh ${sdir}
< ${sdir}/utt2num_frames awk -v maxframes="$maxframes" '{ if ($2 < maxframes) print }' \
| awk -v minframes="$minframes" '{ if ($2 > minframes) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist1
fi
echo "extract utterances having less than $maxchars or more than $minchars characters"
# counting number of chars. Use (NF - 1) instead of NF to exclude the utterance ID column
if [ -z ${nlsyms} ]; then
text2token.py -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
else
text2token.py -l ${nlsyms} -s 1 -n 1 ${sdir}/text --trans_type ${trans_type} \
| awk -v maxchars="$maxchars" '{ if (NF - 1 < maxchars) print }' \
| awk -v minchars="$minchars" '{ if (NF - 1 > minchars) print }' \
| awk '{print $1}' > ${odir}/tmp/reclist2
fi
# extract common lines
comm -12 <(sort ${odir}/tmp/reclist1) <(sort ${odir}/tmp/reclist2) > ${odir}/tmp/reclist
reduce_data_dir.sh ${sdir} ${odir}/tmp/reclist ${odir}
utils/fix_data_dir.sh ${odir}
oldnum=$(wc -l ${sdir}/feats.scp | awk '{print $1}')
newnum=$(wc -l ${odir}/feats.scp | awk '{print $1}')
echo "change from $oldnum to $newnum"

@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
start_pos = None
end_pos = None
for pos in match_pos:
if pos[0] <= i < pos[1]:
start_pos = pos[0]
end_pos = pos[1]
break
return start_pos, end_pos
def get_parser():
parser = argparse.ArgumentParser(
description="convert raw text to tokenized text",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--nchar",
"-n",
default=1,
type=int,
help="number of characters to split, i.e., \
aabb -> a a b b with -n 1 and aa bb with -n 2",
)
parser.add_argument(
"--skip-ncols", "-s", default=0, type=int, help="skip first n columns"
)
parser.add_argument("--space", default="<space>", type=str, help="space symbol")
parser.add_argument(
"--non-lang-syms",
"-l",
default=None,
type=str,
help="list of non-linguistic symobles, e.g., <NOISE> etc.",
)
parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
parser.add_argument(
"--trans_type",
"-t",
type=str,
default="char",
choices=["char", "phn"],
help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
If trans_type is char,
read from SI1279.WRD file -> "bricks are an alternative"
Else if trans_type is phn,
read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
sil t er n ih sil t ih v sil" """,
)
return parser
def main():
parser = get_parser()
args = parser.parse_args()
rs = []
if args.non_lang_syms is not None:
with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
nls = [x.rstrip() for x in f.readlines()]
rs = [re.compile(re.escape(x)) for x in nls]
if args.text:
f = codecs.open(args.text, encoding="utf-8")
else:
f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
sys.stdout = codecs.getwriter("utf-8")(
sys.stdout if is_python2 else sys.stdout.buffer
)
line = f.readline()
n = args.nchar
while line:
x = line.split()
print(" ".join(x[: args.skip_ncols]), end=" ")
a = " ".join(x[args.skip_ncols :])
# get all matched positions
match_pos = []
for r in rs:
i = 0
while i >= 0:
m = r.search(a, i)
if m:
match_pos.append([m.start(), m.end()])
i = m.end()
else:
break
if args.trans_type == "phn":
a = a.split(" ")
else:
if len(match_pos) > 0:
chars = []
i = 0
while i < len(a):
start_pos, end_pos = exist_or_not(i, match_pos)
if start_pos is not None:
chars.append(a[start_pos:end_pos])
i = end_pos
else:
chars.append(a[i])
i += 1
a = chars
a = [a[j : j + n] for j in range(0, len(a), n)]
a_flat = []
for z in a:
a_flat.append("".join(z))
a_chars = [z.replace(" ", args.space) for z in a_flat]
if args.trans_type == "phn":
a_chars = [z.replace("sil", args.space) for z in a_chars]
print(" ".join(a_chars))
line = f.readline()
if __name__ == "__main__":
main()
Loading…
Cancel
Save