You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
251 lines
9.5 KiB
251 lines
9.5 KiB
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# Modified from espnet(https://github.com/espnet/espnet)
|
|
# Modified from wenet(https://github.com/wenet-e2e/wenet)
|
|
from typing import Optional
|
|
|
|
import jsonlines
|
|
from paddle.io import Dataset
|
|
from yacs.config import CfgNode
|
|
|
|
from paddlespeech.s2t.frontend.utility import read_manifest
|
|
from paddlespeech.s2t.utils.log import Log
|
|
|
|
__all__ = ["ManifestDataset", "TransformDataset"]
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
|
|
class ManifestDataset(Dataset):
|
|
@classmethod
|
|
def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
|
|
default = CfgNode(
|
|
dict(
|
|
manifest="",
|
|
max_input_len=27.0,
|
|
min_input_len=0.0,
|
|
max_output_len=float('inf'),
|
|
min_output_len=0.0,
|
|
max_output_input_ratio=float('inf'),
|
|
min_output_input_ratio=0.0, ))
|
|
|
|
if config is not None:
|
|
config.merge_from_other_cfg(default)
|
|
return default
|
|
|
|
@classmethod
|
|
def from_config(cls, config):
|
|
"""Build a ManifestDataset object from a config.
|
|
|
|
Args:
|
|
config (yacs.config.CfgNode): configs object.
|
|
|
|
Returns:
|
|
ManifestDataset: dataet object.
|
|
"""
|
|
assert 'manifest' in config.data
|
|
assert config.data.manifest
|
|
|
|
dataset = cls(
|
|
manifest_path=config.data.manifest,
|
|
max_input_len=config.data.max_input_len,
|
|
min_input_len=config.data.min_input_len,
|
|
max_output_len=config.data.max_output_len,
|
|
min_output_len=config.data.min_output_len,
|
|
max_output_input_ratio=config.data.max_output_input_ratio,
|
|
min_output_input_ratio=config.data.min_output_input_ratio, )
|
|
return dataset
|
|
|
|
def __init__(self,
|
|
manifest_path,
|
|
max_input_len=float('inf'),
|
|
min_input_len=0.0,
|
|
max_output_len=float('inf'),
|
|
min_output_len=0.0,
|
|
max_output_input_ratio=float('inf'),
|
|
min_output_input_ratio=0.0):
|
|
"""Manifest Dataset
|
|
|
|
Args:
|
|
manifest_path (str): manifest josn file path
|
|
max_input_len ([type], optional): maximum output seq length,
|
|
in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
|
|
min_input_len (float, optional): minimum input seq length,
|
|
in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
|
|
max_output_len (float, optional): maximum input seq length,
|
|
in modeling units. Defaults to 500.0.
|
|
min_output_len (float, optional): minimum input seq length,
|
|
in modeling units. Defaults to 0.0.
|
|
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio.
|
|
Defaults to 10.0.
|
|
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio.
|
|
Defaults to 0.05.
|
|
|
|
"""
|
|
super().__init__()
|
|
|
|
# read manifest
|
|
self._manifest = read_manifest(
|
|
manifest_path=manifest_path,
|
|
max_input_len=max_input_len,
|
|
min_input_len=min_input_len,
|
|
max_output_len=max_output_len,
|
|
min_output_len=min_output_len,
|
|
max_output_input_ratio=max_output_input_ratio,
|
|
min_output_input_ratio=min_output_input_ratio)
|
|
self._manifest.sort(key=lambda x: x["input"][0]["shape"][0])
|
|
|
|
def __len__(self):
|
|
return len(self._manifest)
|
|
|
|
def __getitem__(self, idx):
|
|
return self._manifest[idx]
|
|
|
|
|
|
class TransformDataset(Dataset):
|
|
"""Transform Dataset.
|
|
|
|
Args:
|
|
data: list object from make_batchset
|
|
converter: batch function
|
|
reader: read data
|
|
"""
|
|
|
|
def __init__(self, data, converter, reader):
|
|
"""Init function."""
|
|
super().__init__()
|
|
self.data = data
|
|
self.converter = converter
|
|
self.reader = reader
|
|
|
|
def __len__(self):
|
|
"""Len function."""
|
|
return len(self.data)
|
|
|
|
def __getitem__(self, idx):
|
|
"""[] operator."""
|
|
return self.converter([self.reader(self.data[idx], return_uttid=True)])
|
|
|
|
|
|
class AudioDataset(Dataset):
|
|
def __init__(self,
|
|
data_file,
|
|
max_length=10240,
|
|
min_length=0,
|
|
token_max_length=200,
|
|
token_min_length=1,
|
|
batch_type='static',
|
|
batch_size=1,
|
|
max_frames_in_batch=0,
|
|
sort=True,
|
|
raw_wav=True,
|
|
stride_ms=10):
|
|
"""Dataset for loading audio data.
|
|
Attributes::
|
|
data_file: input data file
|
|
Plain text data file, each line contains following 7 fields,
|
|
which is split by '\t':
|
|
utt:utt1
|
|
feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
|
|
feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
|
|
text:i love you
|
|
token: i <space> l o v e <space> y o u
|
|
tokenid: int id of this token
|
|
token_shape: M,N # M is the number of token, N is vocab size
|
|
max_length: drop utterance which is greater than max_length(10ms), unit 10ms.
|
|
min_length: drop utterance which is less than min_length(10ms), unit 10ms.
|
|
token_max_length: drop utterance which is greater than token_max_length,
|
|
especially when use char unit for english modeling
|
|
token_min_length: drop utterance which is less than token_max_length
|
|
batch_type: static or dynamic, see max_frames_in_batch(dynamic)
|
|
batch_size: number of utterances in a batch,
|
|
it's for static batch size.
|
|
max_frames_in_batch: max feature frames in a batch,
|
|
when batch_type is dynamic, it's for dynamic batch size.
|
|
Then batch_size is ignored, we will keep filling the
|
|
batch until the total frames in batch up to max_frames_in_batch.
|
|
sort: whether to sort all data, so the utterance with the same
|
|
length could be filled in a same batch.
|
|
raw_wav: use raw wave or extracted featute.
|
|
if raw wave is used, dynamic waveform-level augmentation could be used
|
|
and the feature is extracted by torchaudio.
|
|
if extracted featute(e.g. by kaldi) is used, only feature-level
|
|
augmentation such as specaug could be used.
|
|
"""
|
|
assert batch_type in ['static', 'dynamic']
|
|
# read manifest
|
|
with jsonlines.open(data_file, 'r') as reader:
|
|
data = list(reader)
|
|
if sort:
|
|
data = sorted(data, key=lambda x: x["feat_shape"][0])
|
|
if raw_wav:
|
|
path_suffix = data[0]['feat'].split(':')[0].splitext()[-1]
|
|
assert path_suffix not in ('.ark', '.scp')
|
|
# m second to n frame
|
|
data = list(
|
|
map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms),
|
|
data))
|
|
|
|
self.input_dim = data[0]['feat_shape'][1]
|
|
self.output_dim = data[0]['token_shape'][1]
|
|
|
|
valid_data = []
|
|
for i in range(len(data)):
|
|
length = data[i]['feat_shape'][0]
|
|
token_length = data[i]['token_shape'][0]
|
|
# remove too lang or too short utt for both input and output
|
|
# to prevent from out of memory
|
|
if length > max_length or length < min_length:
|
|
pass
|
|
elif token_length > token_max_length or token_length < token_min_length:
|
|
pass
|
|
else:
|
|
valid_data.append(data[i])
|
|
logger.info(f"raw dataset len: {len(data)}")
|
|
data = valid_data
|
|
num_data = len(data)
|
|
logger.info(f"dataset len after filter: {num_data}")
|
|
|
|
self.minibatch = []
|
|
# Dynamic batch size
|
|
if batch_type == 'dynamic':
|
|
assert (max_frames_in_batch > 0)
|
|
self.minibatch.append([])
|
|
num_frames_in_batch = 0
|
|
for i in range(num_data):
|
|
length = data[i]['feat_shape'][0]
|
|
num_frames_in_batch += length
|
|
if num_frames_in_batch > max_frames_in_batch:
|
|
self.minibatch.append([])
|
|
num_frames_in_batch = length
|
|
self.minibatch[-1].append(data[i])
|
|
# Static batch size
|
|
else:
|
|
cur = 0
|
|
while cur < num_data:
|
|
end = min(cur + batch_size, num_data)
|
|
item = []
|
|
for i in range(cur, end):
|
|
item.append(data[i])
|
|
self.minibatch.append(item)
|
|
cur = end
|
|
|
|
def __len__(self):
|
|
"""number of example(batch)"""
|
|
return len(self.minibatch)
|
|
|
|
def __getitem__(self, idx):
|
|
"""batch example of idx"""
|
|
return self.minibatch[idx]
|