|
|
|
@ -147,3 +147,131 @@ class TransformDataset(Dataset):
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
|
"""[] operator."""
|
|
|
|
|
return self.converter([self.reader(self.data[idx], return_uttid=True)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioDataset(Dataset):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
data_file,
|
|
|
|
|
max_length=10240,
|
|
|
|
|
min_length=0,
|
|
|
|
|
token_max_length=200,
|
|
|
|
|
token_min_length=1,
|
|
|
|
|
batch_type='static',
|
|
|
|
|
batch_size=1,
|
|
|
|
|
max_frames_in_batch=0,
|
|
|
|
|
sort=True,
|
|
|
|
|
raw_wav=True,
|
|
|
|
|
stride_ms=10):
|
|
|
|
|
"""Dataset for loading audio data.
|
|
|
|
|
Attributes::
|
|
|
|
|
data_file: input data file
|
|
|
|
|
Plain text data file, each line contains following 7 fields,
|
|
|
|
|
which is split by '\t':
|
|
|
|
|
utt:utt1
|
|
|
|
|
feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
|
|
|
|
|
feat_shape: 4.95(in seconds) or feat_shape:495,80(495 is in frames)
|
|
|
|
|
text:i love you
|
|
|
|
|
token: i <space> l o v e <space> y o u
|
|
|
|
|
tokenid: int id of this token
|
|
|
|
|
token_shape: M,N # M is the number of token, N is vocab size
|
|
|
|
|
max_length: drop utterance which is greater than max_length(10ms), unit 10ms.
|
|
|
|
|
min_length: drop utterance which is less than min_length(10ms), unit 10ms.
|
|
|
|
|
token_max_length: drop utterance which is greater than token_max_length,
|
|
|
|
|
especially when use char unit for english modeling
|
|
|
|
|
token_min_length: drop utterance which is less than token_max_length
|
|
|
|
|
batch_type: static or dynamic, see max_frames_in_batch(dynamic)
|
|
|
|
|
batch_size: number of utterances in a batch,
|
|
|
|
|
it's for static batch size.
|
|
|
|
|
max_frames_in_batch: max feature frames in a batch,
|
|
|
|
|
when batch_type is dynamic, it's for dynamic batch size.
|
|
|
|
|
Then batch_size is ignored, we will keep filling the
|
|
|
|
|
batch until the total frames in batch up to max_frames_in_batch.
|
|
|
|
|
sort: whether to sort all data, so the utterance with the same
|
|
|
|
|
length could be filled in a same batch.
|
|
|
|
|
raw_wav: use raw wave or extracted featute.
|
|
|
|
|
if raw wave is used, dynamic waveform-level augmentation could be used
|
|
|
|
|
and the feature is extracted by torchaudio.
|
|
|
|
|
if extracted featute(e.g. by kaldi) is used, only feature-level
|
|
|
|
|
augmentation such as specaug could be used.
|
|
|
|
|
"""
|
|
|
|
|
assert batch_type in ['static', 'dynamic']
|
|
|
|
|
# read manifest
|
|
|
|
|
data = read_manifest(data_file)
|
|
|
|
|
if sort:
|
|
|
|
|
data = sorted(data, key=lambda x: x["feat_shape"][0])
|
|
|
|
|
if raw_wav:
|
|
|
|
|
assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark',
|
|
|
|
|
'.scp')
|
|
|
|
|
data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms))
|
|
|
|
|
|
|
|
|
|
self.input_dim = data[0]['feat_shape'][1]
|
|
|
|
|
self.output_dim = data[0]['token_shape'][1]
|
|
|
|
|
|
|
|
|
|
# with open(data_file, 'r') as f:
|
|
|
|
|
# for line in f:
|
|
|
|
|
# arr = line.strip().split('\t')
|
|
|
|
|
# if len(arr) != 7:
|
|
|
|
|
# continue
|
|
|
|
|
# key = arr[0].split(':')[1]
|
|
|
|
|
# tokenid = arr[5].split(':')[1]
|
|
|
|
|
# output_dim = int(arr[6].split(':')[1].split(',')[1])
|
|
|
|
|
# if raw_wav:
|
|
|
|
|
# wav_path = ':'.join(arr[1].split(':')[1:])
|
|
|
|
|
# duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
|
|
|
|
|
# data.append((key, wav_path, duration, tokenid))
|
|
|
|
|
# else:
|
|
|
|
|
# feat_ark = ':'.join(arr[1].split(':')[1:])
|
|
|
|
|
# feat_info = arr[2].split(':')[1].split(',')
|
|
|
|
|
# feat_dim = int(feat_info[1].strip())
|
|
|
|
|
# num_frames = int(feat_info[0].strip())
|
|
|
|
|
# data.append((key, feat_ark, num_frames, tokenid))
|
|
|
|
|
# self.input_dim = feat_dim
|
|
|
|
|
# self.output_dim = output_dim
|
|
|
|
|
|
|
|
|
|
valid_data = []
|
|
|
|
|
for i in range(len(data)):
|
|
|
|
|
length = data[i]['feat_shape'][0]
|
|
|
|
|
token_length = data[i]['token_shape'][0]
|
|
|
|
|
# remove too lang or too short utt for both input and output
|
|
|
|
|
# to prevent from out of memory
|
|
|
|
|
if length > max_length or length < min_length:
|
|
|
|
|
# logging.warn('ignore utterance {} feature {}'.format(
|
|
|
|
|
# data[i][0], length))
|
|
|
|
|
pass
|
|
|
|
|
elif token_length > token_max_length or token_length < token_min_length:
|
|
|
|
|
pass
|
|
|
|
|
else:
|
|
|
|
|
valid_data.append(data[i])
|
|
|
|
|
data = valid_data
|
|
|
|
|
|
|
|
|
|
self.minibatch = []
|
|
|
|
|
num_data = len(data)
|
|
|
|
|
# Dynamic batch size
|
|
|
|
|
if batch_type == 'dynamic':
|
|
|
|
|
assert (max_frames_in_batch > 0)
|
|
|
|
|
self.minibatch.append([])
|
|
|
|
|
num_frames_in_batch = 0
|
|
|
|
|
for i in range(num_data):
|
|
|
|
|
length = data[i]['feat_shape'][0]
|
|
|
|
|
num_frames_in_batch += length
|
|
|
|
|
if num_frames_in_batch > max_frames_in_batch:
|
|
|
|
|
self.minibatch.append([])
|
|
|
|
|
num_frames_in_batch = length
|
|
|
|
|
self.minibatch[-1].append(data[i])
|
|
|
|
|
# Static batch size
|
|
|
|
|
else:
|
|
|
|
|
cur = 0
|
|
|
|
|
while cur < num_data:
|
|
|
|
|
end = min(cur + batch_size, num_data)
|
|
|
|
|
item = []
|
|
|
|
|
for i in range(cur, end):
|
|
|
|
|
item.append(data[i])
|
|
|
|
|
self.minibatch.append(item)
|
|
|
|
|
cur = end
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return len(self.minibatch)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
|
|
instance = self.minibatch[idx]
|
|
|
|
|
return instance["utt"], instance["feat"], instance["text"]
|
|
|
|
|