# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet)
import os
from typing import Optional

from paddle.io import Dataset
from yacs.config import CfgNode

from paddlespeech.s2t.frontend.utility import read_manifest
from paddlespeech.s2t.utils.log import Log

__all__ = ["ManifestDataset", "TransformDataset"]

logger = Log(__name__).getlog()


class ManifestDataset(Dataset):
    """Dataset backed by a manifest file, filtered by length and sorted."""

    @classmethod
    def params(cls, config: Optional[CfgNode] = None) -> CfgNode:
        """Return the default data configuration.

        Args:
            config (CfgNode, optional): when given, the defaults are merged
                into it in place via ``merge_from_other_cfg``.

        Returns:
            CfgNode: the default configuration node.
        """
        default = CfgNode(
            dict(
                manifest="",
                max_input_len=27.0,
                min_input_len=0.0,
                max_output_len=float('inf'),
                min_output_len=0.0,
                max_output_input_ratio=float('inf'),
                min_output_input_ratio=0.0, ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a ManifestDataset object from a config.

        Args:
            config (yacs.config.CfgNode): configs object; the dataset
                options are read from ``config.data``.

        Returns:
            ManifestDataset: dataset object.
        """
        assert 'manifest' in config.data
        assert config.data.manifest

        dataset = cls(
            manifest_path=config.data.manifest,
            max_input_len=config.data.max_input_len,
            min_input_len=config.data.min_input_len,
            max_output_len=config.data.max_output_len,
            min_output_len=config.data.min_output_len,
            max_output_input_ratio=config.data.max_output_input_ratio,
            min_output_input_ratio=config.data.min_output_input_ratio, )
        return dataset

    def __init__(self,
                 manifest_path,
                 max_input_len=float('inf'),
                 min_input_len=0.0,
                 max_output_len=float('inf'),
                 min_output_len=0.0,
                 max_output_input_ratio=float('inf'),
                 min_output_input_ratio=0.0):
        """Manifest Dataset

        All length/ratio limits are forwarded unchanged to ``read_manifest``.

        Args:
            manifest_path (str): manifest json file path
            max_input_len (float, optional): maximum input seq length,
                in seconds for raw wav, in frame numbers for feature data.
                Defaults to float('inf').
            min_input_len (float, optional): minimum input seq length,
                in seconds for raw wav, in frame numbers for feature data.
                Defaults to 0.0.
            max_output_len (float, optional): maximum output seq length,
                in modeling units. Defaults to float('inf').
            min_output_len (float, optional): minimum output seq length,
                in modeling units. Defaults to 0.0.
            max_output_input_ratio (float, optional): maximum
                output seq length/input seq length ratio.
                Defaults to float('inf').
            min_output_input_ratio (float, optional): minimum
                output seq length/input seq length ratio. Defaults to 0.0.
        """
        super().__init__()

        # read manifest
        self._manifest = read_manifest(
            manifest_path=manifest_path,
            max_input_len=max_input_len,
            min_input_len=min_input_len,
            max_output_len=max_output_len,
            min_output_len=min_output_len,
            max_output_input_ratio=max_output_input_ratio,
            min_output_input_ratio=min_output_input_ratio)
        # Order entries by the first dimension of the first input's shape.
        self._manifest.sort(key=lambda x: x["input"][0]["shape"][0])

    def __len__(self):
        """Number of manifest entries."""
        return len(self._manifest)

    def __getitem__(self, idx):
        """Return the manifest entry at position ``idx``."""
        return self._manifest[idx]


class TransformDataset(Dataset):
    """Transform Dataset.

    Args:
        data: list object from make_batchset
        converter: batch function
        reader: read data
    """

    def __init__(self, data, converter, reader):
        """Init function."""
        super().__init__()
        self.data = data
        self.converter = converter
        self.reader = reader

    def __len__(self):
        """Len function."""
        return len(self.data)

    def __getitem__(self, idx):
        """[] operator.

        Reads ``data[idx]`` with ``reader`` (asking it to return utterance
        ids too) and passes the result, wrapped in a one-element list, to
        ``converter``.
        """
        return self.converter([self.reader(self.data[idx], return_uttid=True)])


class AudioDataset(Dataset):
    def __init__(self,
                 data_file,
                 max_length=10240,
                 min_length=0,
                 token_max_length=200,
                 token_min_length=1,
                 batch_type='static',
                 batch_size=1,
                 max_frames_in_batch=0,
                 sort=True,
                 raw_wav=True,
                 stride_ms=10):
        """Dataset for loading audio data.

        Each item of this dataset is a whole mini-batch (a list of
        manifest entries), not a single utterance.

        Args:
            data_file: input data file
                Plain text data file, each line contains following 7 fields,
                which is split by '\\t':
                    utt:utt1
                    feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
                    feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
                    text:i love you
                    token: i <space> l o v e <space> y o u
                    tokenid: int id of this token
                    token_shape: M,N    # M is the number of token, N is vocab size
            max_length: drop utterance which is greater than max_length(10ms), unit 10ms.
            min_length: drop utterance which is less than min_length(10ms), unit 10ms.
            token_max_length: drop utterance which is greater than token_max_length,
                especially when use char unit for english modeling
            token_min_length: drop utterance which is less than token_min_length
            batch_type: static or dynamic, see max_frames_in_batch(dynamic)
            batch_size: number of utterances in a batch,
                it's for static batch size.
            max_frames_in_batch: max feature frames in a batch,
                when batch_type is dynamic, it's for dynamic batch size.
                Then batch_size is ignored, we will keep filling the
                batch until the total frames in batch up to max_frames_in_batch.
            sort: whether to sort all data, so the utterance with the same
                length could be filled in a same batch.
            raw_wav: use raw wave or extracted feature.
                if raw wave is used, dynamic waveform-level augmentation could be used
                and the feature is extracted on the fly.
                if extracted feature(e.g. by kaldi) is used, only feature-level
                augmentation such as specaug could be used.
            stride_ms: frame shift in milliseconds, used to convert raw wav
                durations (seconds) into frame counts.

        Raises:
            AssertionError: if ``batch_type`` is unknown, if ``raw_wav`` is set
                but the first entry points at a '.ark'/'.scp' file, or if
                ``batch_type`` is 'dynamic' and ``max_frames_in_batch`` <= 0.
        """
        assert batch_type in ['static', 'dynamic']
        # read manifest
        data = read_manifest(data_file)
        if sort:
            data = sorted(data, key=lambda x: x["feat_shape"][0])

        if raw_wav:
            # 'feat' may carry an ':offset' suffix; only the path part matters.
            feat_path = data[0]['feat'].split(':')[0]
            # str has no .splitext(); the extension comes from os.path.
            path_suffix = os.path.splitext(feat_path)[-1]
            assert path_suffix not in ('.ark', '.scp')
            # Convert durations in seconds to frame counts. Entries are
            # shallow-copied so that every other field stays available below
            # and the objects returned by read_manifest are not mutated.
            # NOTE(review): assumes feat_shape is a list/tuple whose first
            # element is the duration - confirm against the manifest format.
            data = [
                dict(
                    item,
                    feat_shape=[
                        float(item['feat_shape'][0]) * 1000 / stride_ms,
                        *item['feat_shape'][1:]
                    ]) for item in data
            ]

        self.input_dim = data[0]['feat_shape'][1]
        self.output_dim = data[0]['token_shape'][1]

        # remove too long or too short utt for both input and output
        # to prevent from out of memory
        valid_data = []
        for item in data:
            length = item['feat_shape'][0]
            token_length = item['token_shape'][0]
            if length > max_length or length < min_length:
                continue
            if token_length > token_max_length or token_length < token_min_length:
                continue
            valid_data.append(item)

        logger.info(f"raw dataset len: {len(data)}")
        data = valid_data
        num_data = len(data)
        logger.info(f"dataset len after filter: {num_data}")

        self.minibatch = []
        if batch_type == 'dynamic':
            # Dynamic batch size: fill a batch until adding the next
            # utterance would exceed max_frames_in_batch.
            assert (max_frames_in_batch > 0)
            num_frames_in_batch = 0
            for item in data:
                length = item['feat_shape'][0]
                num_frames_in_batch += length
                # Open a batch lazily so that no empty batch is produced when
                # the first utterance alone exceeds the limit or when the
                # filtered dataset is empty.
                if not self.minibatch or num_frames_in_batch > max_frames_in_batch:
                    self.minibatch.append([])
                    num_frames_in_batch = length
                self.minibatch[-1].append(item)
        else:
            # Static batch size: consecutive slices of batch_size utterances.
            # range() raises ValueError for batch_size == 0 instead of
            # looping forever.
            self.minibatch = [
                data[start:start + batch_size]
                for start in range(0, num_data, batch_size)
            ]

    def __len__(self):
        """number of example(batch)"""
        return len(self.minibatch)

    def __getitem__(self, idx):
        """batch example of idx"""
        return self.minibatch[idx]