From 6f7e9656febf8e399ef09b749da7641bee438dc5 Mon Sep 17 00:00:00 2001 From: xiongxinlei Date: Fri, 25 Feb 2022 20:05:25 +0800 Subject: [PATCH] add kaldi feats ark dataset --- paddlespeech/vector/datasets/dataset.py | 143 ++++++++++++++++++++++++ paddlespeech/vector/utils/data_utils.py | 125 +++++++++++++++++++++ paddlespeech/vector/utils/utils.py | 132 ++++++++++++++++++++++ 3 files changed, 400 insertions(+) create mode 100644 paddlespeech/vector/datasets/dataset.py create mode 100644 paddlespeech/vector/utils/data_utils.py create mode 100644 paddlespeech/vector/utils/utils.py diff --git a/paddlespeech/vector/datasets/dataset.py b/paddlespeech/vector/datasets/dataset.py new file mode 100644 index 00000000..e7030053 --- /dev/null +++ b/paddlespeech/vector/datasets/dataset.py @@ -0,0 +1,143 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import random +import numpy as np +import kaldi_python_io as k_io +from paddle.io import Dataset +from paddlespeech.vector.utils.data_utils import batch_pad_right +import paddlespeech.vector.utils as utils +from paddlespeech.vector.utils.utils import read_map_file + +def ark_collate_fn(batch): + """ + Custom collate function for kaldi feats dataset + + Args: + min_chunk_size: min chunk size of a utterance + max_chunk_size: max chunk size of a utterance + + Returns: + ark_collate_fn: collate funtion for dataloader + """ + + data = [] + target = [] + for items in batch: + for x, y in zip(items[0], items[1]): + data.append(np.array(x)) + target.append(y) + + data, lengths = batch_pad_right(data) + return np.array(data, dtype=np.float32), \ + np.array(lengths, dtype=np.float32), \ + np.array(target, dtype=np.long).reshape((len(target), 1)) + + +class KaldiArkDataset(Dataset): + """ + Dataset used to load kaldi ark/scp files. + """ + def __init__(self, scp_file, label2utt, min_item_size=1, + max_item_size=1, repeat=50, min_chunk_size=200, + max_chunk_size=400, select_by_speaker=True): + self.scp_file = scp_file + self.scp_reader = None + self.repeat = repeat + self.min_item_size = min_item_size + self.max_item_size = max_item_size + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + self._collate_fn = ark_collate_fn + self._is_select_by_speaker = select_by_speaker + if utils.is_exist(self.scp_file): + self.scp_reader = k_io.ScriptReader(self.scp_file) + + label2utts, utt2label = read_map_file(label2utt, key_func=int) + self.utt_info = list(label2utts.items()) if self._is_select_by_speaker else list(utt2label.items()) + + @property + def collate_fn(self): + """ + Return a collate funtion. + """ + return self._collate_fn + + def _random_chunk(self, length): + chunk_size = random.randint(self.min_chunk_size, self.max_chunk_size) + if chunk_size >= length: + return 0, length + start = random.randint(0, length - chunk_size) + end = start + chunk_size + + return start, end + + def _select_by_speaker(self, index): + if self.scp_reader is None or not self.utt_info: + return [] + index = index % (len(self.utt_info)) + inputs = [] + labels = [] + item_size = random.randint(self.min_item_size, self.max_item_size) + for loop_idx in range(item_size): + try: + utt_index = random.randint(0, len(self.utt_info[index][1])) \ + % len(self.utt_info[index][1]) + key = self.utt_info[index][1][utt_index] + except: + print(index, utt_index, len(self.utt_info[index][1])) + sys.exit(-1) + x = self.scp_reader[key] + x = np.transpose(x) + bg, end = self._random_chunk(x.shape[-1]) + inputs.append(x[:, bg: end]) + labels.append(self.utt_info[index][0]) + return inputs, labels + + def _select_by_utt(self, index): + if self.scp_reader is None or len(self.utt_info) == 0: + return {} + index = index % (len(self.utt_info)) + key = self.utt_info[index][0] + x = self.scp_reader[key] + x = np.transpose(x) + bg, end = self._random_chunk(x.shape[-1]) + + y = self.utt_info[index][1] + + return [x[:, bg: end]], [y] + + def __getitem__(self, index): + if self._is_select_by_speaker: + return self._select_by_speaker(index) + else: + return self._select_by_utt(index) + + def __len__(self): + return len(self.utt_info) * self.repeat + + def __iter__(self): + self._start = 0 + return self + + def __next__(self): + if self._start < len(self): + ret = self[self._start] + self._start += 1 + return ret + else: + raise StopIteration + +return KaldiArkDataset diff --git a/paddlespeech/vector/utils/data_utils.py b/paddlespeech/vector/utils/data_utils.py new file mode 100644 index 00000000..4a33a795 --- /dev/null +++ b/paddlespeech/vector/utils/data_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +data utilities +""" +import os +import sys +import numpy +import paddle + + +def pad_right_to(array, target_shape, mode="constant", value=0): + """ + This function takes a numpy array of arbitrary shape and pads it to target + shape by appending values on the right. + + Args: + array: input numpy array. Input array whose dimension we need to pad. + target_shape : (list, tuple). Target shape we want for the target array its len must be equal to array.ndim + mode : str. Pad mode, please refer to numpy.pad documentation. + value : float. Pad value, please refer to numpy.pad documentation. + + Returns: + array: numpy.array. Padded array. + valid_vals : list. List containing proportion for each dimension of original, non-padded values. + """ + assert len(target_shape) == array.ndim + pads = [] # this contains the abs length of the padding for each dimension. + valid_vals = [] # thic contains the relative lengths for each dimension. + i = 0 # iterating over target_shape ndims + while i < len(target_shape): + assert ( + target_shape[i] >= array.shape[i] + ), "Target shape must be >= original shape for every dim" + pads.append([0, target_shape[i] - array.shape[i]]) + valid_vals.append(array.shape[i] / target_shape[i]) + i += 1 + + array = numpy.pad(array, pads, mode=mode, constant_values=value) + + return array, valid_vals + + +def batch_pad_right(arrays, mode="constant", value=0): + """Given a list of numpy arrays it batches them together by padding to the right + on each dimension in order to get same length for all. + + Args: + arrays : list. List of array we wish to pad together. + mode : str. Padding mode see numpy.pad documentation. + value : float. Padding value see numpy.pad documentation. + + Returns: + array : numpy.array. Padded array. + valid_vals : list. List containing proportion for each dimension of original, non-padded values. + """ + + if not len(arrays): + raise IndexError("arrays list must not be empty") + + if len(arrays) == 1: + # if there is only one array in the batch we simply unsqueeze it. + return numpy.expand_dims(arrays[0], axis=0), numpy.array([1.0]) + + if not ( + any( + [arrays[i].ndim == arrays[0].ndim for i in range(1, len(arrays))] + ) + ): + raise IndexError("All arrays must have same number of dimensions") + + # FIXME we limit the support here: we allow padding of only the last dimension + # need to remove this when feat extraction is updated to handle multichannel. + max_shape = [] + for dim in range(arrays[0].ndim): + if dim != (arrays[0].ndim - 1): + if not all( + [x.shape[dim] == arrays[0].shape[dim] for x in arrays[1:]] + ): + raise EnvironmentError( + "arrays should have same dimensions except for last one" + ) + max_shape.append(max([x.shape[dim] for x in arrays])) + + batched = [] + valid = [] + for t in arrays: + # for each array we apply pad_right_to + padded, valid_percent = pad_right_to( + t, max_shape, mode=mode, value=value + ) + batched.append(padded) + valid.append(valid_percent[-1]) + + batched = numpy.stack(batched) + + return batched, numpy.array(valid) + + +def length_to_mask(length, max_len=None, dtype=None): + """Creates a binary mask for each sequence. + """ + assert len(length.shape) == 1 + + if max_len is None: + max_len = paddle.cast(paddle.max(length), dtype="int64") # using arange to generate mask + mask = paddle.arange(max_len, dtype=length.dtype).expand([paddle.shape(length)[0], max_len]) < length.unsqueeze(1) + + if dtype is None: + dtype = length.dtype + + mask = paddle.cast(mask, dtype=dtype) + return mask diff --git a/paddlespeech/vector/utils/utils.py b/paddlespeech/vector/utils/utils.py new file mode 100644 index 00000000..c46e42c2 --- /dev/null +++ b/paddlespeech/vector/utils/utils.py @@ -0,0 +1,132 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +utilities +""" +import os +import sys +import paddle +import numpy as np + +from sidt import _logger as log + + +def exit_if_not_exist(in_path): + """ + Check the existence of a file or directory, if not exit, exit the program. + + Args: + in_path: input dicrector + """ + if not is_exist(in_path): + sys.exit(-1) + + +def is_exist(in_path): + """ + Check the existence of a file or directory + + Args: + in_path: input dicrector + + Returns: + True or False + """ + if not os.path.exists(in_path): + log.error("No such file or directory: %s" % (in_path)) + return False + + return True + + +def get_latest_file(target_dir): + """ + Get the latest file in target directory + + Args: + target_dir: target directory + + Returns: + latest_file: a string or None + """ + items = os.listdir(target_dir) + items.sort(key=lambda fn: os.path.getmtime(os.path.join(target_dir, fn)) \ + if not os.path.isdir(os.path.join(target_dir, fn)) else 0) + latest_file = None if not items else os.path.join(target_dir, items[-1]) + return latest_file + + +def avg_models(models): + """ + merge multiple models + """ + checkpoint_dict = paddle.load(models[0]) + final_state_dict = checkpoint_dict + + if len(models) > 1: + for model in models[1:]: + checkpoint_dict = paddle.load(model) + for k, v in checkpoint_dict.items(): + final_state_dict[k] += v + for k in final_state_dict.keys(): + final_state_dict[k] /= float(len(models)) + if np.any(np.isnan(final_state_dict[k])): + print("Nan in %s" % (k)) + + return final_state_dict + +def Q_from_tokens(token_num): + """ + get prior model, data from uniform, would support others(guassian) in future + """ + freq = [1] * token_num + Q = paddle.to_tensor(freq, dtype = 'float64') + return Q / Q.sum() + + +def read_map_file(map_file, key_func=None, value_func=None, values_func=None): + """ Read map file. First colume is key, the rest columes are values. + + Args: + map_file: map file + key_func: convert function for key + value_func: convert function for each value + values_func: convert function for values + + Returns: + dict: key 2 value + dict: value 2 key + """ + if not is_exist(map_file): + sys.exit(0) + + key2val = {} + val2key = {} + with open(map_file, 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + items = line.split() + assert len(items) >= 2 + key = items[0] if not key_func else key_func(items[0]) + values = items[1:] if not value_func else [value_func(item) for item in items[1:]] + if values_func: + values = values_func(values) + key2val[key] = values + for value in values: + val2key[value] = key + + return key2val, val2key