diff --git a/paddlespeech/vector/__init__.py b/paddlespeech/vector/__init__.py index 185a92b8..2a3588ec 100644 --- a/paddlespeech/vector/__init__.py +++ b/paddlespeech/vector/__init__.py @@ -11,3 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +__init__ file for sidt package. +""" + +import logging as sidt_logging +import colorlog + +LOG_COLOR_CONFIG = { + 'DEBUG': 'white', + 'INFO': 'white', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'purple', +} + +# 设置全局的logger +colored_formatter = colorlog.ColoredFormatter( + '%(log_color)s [%(levelname)s] [%(asctime)s] [%(filename)s:%(lineno)d] - %(message)s', + datefmt="%Y-%m-%d %H:%M:%S", + log_colors=LOG_COLOR_CONFIG) # 日志输出格式 +_logger = sidt_logging.getLogger("sidt") +handler = colorlog.StreamHandler() +handler.setLevel(sidt_logging.INFO) +handler.setFormatter(colored_formatter) +_logger.addHandler(handler) +_logger.setLevel(sidt_logging.INFO) + +from .trainer.trainer import Trainer +from .dataset.ark_dataset import create_kaldi_ark_dataset +from .dataset.egs_dataset import create_kaldi_egs_dataset diff --git a/paddlespeech/vector/datasets/ark_dataset.py b/paddlespeech/vector/datasets/ark_dataset.py new file mode 100755 index 00000000..7a00e7ba --- /dev/null +++ b/paddlespeech/vector/datasets/ark_dataset.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import random +import numpy as np +import kaldi_python_io as k_io +from paddle.io import Dataset +from paddlespeech.vector.utils.data_utils import batch_pad_right +import paddlespeech.vector.utils as utils +from paddlespeech.vector.utils.utils import read_map_file +from paddlespeech.vector import _logger as log + +def ark_collate_fn(batch): + """ + Custom collate function] for kaldi feats dataset + + Args: + min_chunk_size: min chunk size of a utterance + max_chunk_size: max chunk size of a utterance + + Returns: + ark_collate_fn: collate funtion for dataloader + """ + + data = [] + target = [] + for items in batch: + for x, y in zip(items[0], items[1]): + data.append(np.array(x)) + target.append(y) + + data, lengths = batch_pad_right(data) + return np.array(data, dtype=np.float32), \ + np.array(lengths, dtype=np.float32), \ + np.array(target, dtype=np.long).reshape((len(target), 1)) + + +class KaldiArkDataset(Dataset): + """ + Dataset used to load kaldi ark/scp files. + """ + def __init__(self, scp_file, label2utt, min_item_size=1, + max_item_size=1, repeat=50, min_chunk_size=200, + max_chunk_size=400, select_by_speaker=True): + self.scp_file = scp_file + self.scp_reader = None + self.repeat = repeat + self.min_item_size = min_item_size + self.max_item_size = max_item_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + self._collate_fn = ark_collate_fn + self._is_select_by_speaker = select_by_speaker + if utils.is_exist(self.scp_file): + self.scp_reader = k_io.ScriptReader(self.scp_file) + + label2utts, utt2label = read_map_file(label2utt, key_func=int) + self.utt_info = list(label2utts.items()) if self._is_select_by_speaker else list(utt2label.items()) + + @property + def collate_fn(self): + """ + Return a collate funtion. + """ + return self._collate_fn + + def _random_chunk(self, length): + chunk_size = random.randint(self.min_chunk_size, self.max_chunk_size) + if chunk_size >= length: + return 0, length + start = random.randint(0, length - chunk_size) + end = start + chunk_size + + return start, end + + def _select_by_speaker(self, index): + if self.scp_reader is None or not self.utt_info: + return [] + index = index % (len(self.utt_info)) + inputs = [] + labels = [] + item_size = random.randint(self.min_item_size, self.max_item_size) + for loop_idx in range(item_size): + try: + utt_index = random.randint(0, len(self.utt_info[index][1])) \ + % len(self.utt_info[index][1]) + key = self.utt_info[index][1][utt_index] + except: + print(index, utt_index, len(self.utt_info[index][1])) + sys.exit(-1) + x = self.scp_reader[key] + x = np.transpose(x) + bg, end = self._random_chunk(x.shape[-1]) + inputs.append(x[:, bg: end]) + labels.append(self.utt_info[index][0]) + return inputs, labels + + def _select_by_utt(self, index): + if self.scp_reader is None or len(self.utt_info) == 0: + return {} + index = index % (len(self.utt_info)) + key = self.utt_info[index][0] + x = self.scp_reader[key] + x = np.transpose(x) + bg, end = self._random_chunk(x.shape[-1]) + + y = self.utt_info[index][1] + + return [x[:, bg: end]], [y] + + def __getitem__(self, index): + if self._is_select_by_speaker: + return self._select_by_speaker(index) + else: + return self._select_by_utt(index) + + def __len__(self): + return len(self.utt_info) * self.repeat + + def __iter__(self): + self._start = 0 + return self + + def __next__(self): + if self._start < len(self): + ret = self[self._start] + self._start += 1 + return ret + else: + raise StopIteration diff --git a/paddlespeech/vector/datasets/egs_dataset.py b/paddlespeech/vector/datasets/egs_dataset.py new file mode 100644 index 00000000..53130d5f --- /dev/null +++ b/paddlespeech/vector/datasets/egs_dataset.py @@ -0,0 +1,91 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Load nnet3 training egs which generated by kaldi +""" + +import random +import numpy as np +import kaldi_python_io as k_io +from paddle.io import Dataset +import paddlespeech.vector.utils.utils as utils +from paddlespeech.vector import _logger as log +class KaldiEgsDataset(Dataset): + """ + Dataset used to load kaldi nnet3 egs files. + """ + def __init__(self, egs_list_file, egs_idx, transforms=None): + self.scp_reader = None + self.subset_idx = egs_idx - 1 + self.transforms = transforms + if not utils.is_exist(egs_list_file): + return + + self.egs_files = [] + with open(egs_list_file, 'r') as in_fh: + for line in in_fh: + if line.strip(): + self.egs_files.append(line.strip()) + + self.next_subset() + + def next_subset(self, target_index=None, delta_index=None): + """ + Use next specific subset + + Args: + target_index: target egs index + delta_index: incremental value of egs index + """ + if self.egs_files: + if target_index: + self.subset_idx = target_index + else: + delta_index = delta_index if delta_index else 1 + self.subset_idx += delta_index + log.info("egs dataset subset index: %d" % (self.subset_idx)) + egs_file = self.egs_files[self.subset_idx % len(self.egs_files)] + if utils.is_exist(egs_file): + self.scp_reader = k_io.Nnet3EgsScriptReader(egs_file) + else: + log.warning("No such file or directory: %s" % (egs_file)) + + def __getitem__(self, index): + if self.scp_reader is None: + return {} + index %= len(self) + in_dict, out_dict = self.scp_reader[index] + x = np.array(in_dict['matrix']) + x = np.transpose(x) + y = np.array(out_dict['matrix'][0][0][0], dtype=np.int).reshape((1,)) + if self.transforms is not None: + idx = random.randint(0, len(self.transforms) - 1) + x = self.transforms[idx](x) + return x, y + + def __len__(self): + return len(self.scp_reader) + + def __iter__(self): + self._start = 0 + return self + + def __next__(self): + if self._start < len(self): + ret = self[self._start] + self._start += 1 + return ret + else: + raise StopIteration \ No newline at end of file diff --git a/paddlespeech/vector/utils/data_utils.py b/paddlespeech/vector/utils/data_utils.py old mode 100644 new mode 100755 diff --git a/paddlespeech/vector/utils/utils.py b/paddlespeech/vector/utils/utils.py old mode 100644 new mode 100755 index c46e42c2..a28cb526 --- a/paddlespeech/vector/utils/utils.py +++ b/paddlespeech/vector/utils/utils.py @@ -20,7 +20,7 @@ import sys import paddle import numpy as np -from sidt import _logger as log +from paddlespeech.vector import _logger as log def exit_if_not_exist(in_path):