diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 811da39b..6bf01900 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -199,11 +199,11 @@ class U2Trainer(Trainer):
                 report("Rank", dist.get_rank())
                 report("epoch", self.epoch)
                 report('step', self.iteration)
-                report('iter', batch_index + 1)
-                report('total',len(self.train_loader))
                 report("lr", self.lr_scheduler())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
+                report('iter', batch_index + 1)
+                report('total', len(self.train_loader))
                 report('reader_cost', dataload_time)
                 observation['batch_cost'] = observation[
                     'reader_cost'] + observation['step_cost']
diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py
index df300479..15b89ab9 100644
--- a/deepspeech/io/collator.py
+++ b/deepspeech/io/collator.py
@@ -292,10 +292,6 @@ class SpeechCollator():
         olens = np.array(text_lens).astype(np.int64)
         return utts, xs_pad, ilens, ys_pad, olens
 
-    @property
-    def manifest(self):
-        return self._manifest
-
     @property
     def vocab_size(self):
         return self._speech_featurizer.vocab_size
diff --git a/deepspeech/io/dataset.py b/deepspeech/io/dataset.py
index e58e03b4..56e53475 100644
--- a/deepspeech/io/dataset.py
+++ b/deepspeech/io/dataset.py
@@ -147,3 +147,131 @@ class TransformDataset(Dataset):
     def __getitem__(self, idx):
         """[] operator."""
         return self.converter([self.reader(self.data[idx], return_uttid=True)])
+
+
+class AudioDataset(Dataset):
+    def __init__(self,
+                 data_file,
+                 max_length=10240,
+                 min_length=0,
+                 token_max_length=200,
+                 token_min_length=1,
+                 batch_type='static',
+                 batch_size=1,
+                 max_frames_in_batch=0,
+                 sort=True,
+                 raw_wav=True,
+                 stride_ms=10):
+        """Dataset for loading audio data.
+
+        Args:
+            data_file: input data file.
+                Plain text data file; each line contains the following
+                7 fields, split by '\t':
+                    utt:utt1
+                    feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
+                    feat_shape:4.95 (in seconds) or feat_shape:495,80 (495 is in frames)
+                    text:i love you
+                    token:i l o v e y o u
+                    tokenid: int id of this token
+                    token_shape:M,N  # M is the number of tokens, N is vocab size
+            max_length: drop utterances longer than max_length (unit: 10ms).
+            min_length: drop utterances shorter than min_length (unit: 10ms).
+            token_max_length: drop utterances with more than token_max_length
+                tokens, especially when using char units for English modeling.
+            token_min_length: drop utterances with fewer than token_min_length
+                tokens.
+            batch_type: 'static' or 'dynamic'; see max_frames_in_batch (dynamic).
+            batch_size: number of utterances in a batch;
+                used for static batch size.
+            max_frames_in_batch: maximum feature frames in a batch;
+                used for dynamic batch size when batch_type is 'dynamic'.
+                Then batch_size is ignored: we keep filling a batch until
+                its total frames reach max_frames_in_batch.
+            sort: whether to sort all data, so utterances of similar length
+                can be filled into the same batch.
+            stride_ms: frame shift in milliseconds, used to convert a raw-wav
+                duration into a frame count.
+            raw_wav: use raw waveform or extracted features.
+                If raw waveform is used, dynamic waveform-level augmentation
+                can be applied and the feature is extracted on the fly.
+                If extracted features (e.g. by Kaldi) are used, only
+                feature-level augmentation such as SpecAugment can be applied.
+ """ + assert batch_type in ['static', 'dynamic'] + # read manifest + data = read_manifest(data_file) + if sort: + data = sorted(data, key=lambda x: x["feat_shape"][0]) + if raw_wav: + assert data[0]['feat'].split(':')[0].splitext()[-1] not in ('.ark', + '.scp') + data = map(lambda x: (float(x['feat_shape'][0]) * 1000 / stride_ms)) + + self.input_dim = data[0]['feat_shape'][1] + self.output_dim = data[0]['token_shape'][1] + + # with open(data_file, 'r') as f: + # for line in f: + # arr = line.strip().split('\t') + # if len(arr) != 7: + # continue + # key = arr[0].split(':')[1] + # tokenid = arr[5].split(':')[1] + # output_dim = int(arr[6].split(':')[1].split(',')[1]) + # if raw_wav: + # wav_path = ':'.join(arr[1].split(':')[1:]) + # duration = int(float(arr[2].split(':')[1]) * 1000 / 10) + # data.append((key, wav_path, duration, tokenid)) + # else: + # feat_ark = ':'.join(arr[1].split(':')[1:]) + # feat_info = arr[2].split(':')[1].split(',') + # feat_dim = int(feat_info[1].strip()) + # num_frames = int(feat_info[0].strip()) + # data.append((key, feat_ark, num_frames, tokenid)) + # self.input_dim = feat_dim + # self.output_dim = output_dim + + valid_data = [] + for i in range(len(data)): + length = data[i]['feat_shape'][0] + token_length = data[i]['token_shape'][0] + # remove too lang or too short utt for both input and output + # to prevent from out of memory + if length > max_length or length < min_length: + # logging.warn('ignore utterance {} feature {}'.format( + # data[i][0], length)) + pass + elif token_length > token_max_length or token_length < token_min_length: + pass + else: + valid_data.append(data[i]) + data = valid_data + + self.minibatch = [] + num_data = len(data) + # Dynamic batch size + if batch_type == 'dynamic': + assert (max_frames_in_batch > 0) + self.minibatch.append([]) + num_frames_in_batch = 0 + for i in range(num_data): + length = data[i]['feat_shape'][0] + num_frames_in_batch += length + if num_frames_in_batch > max_frames_in_batch: + self.minibatch.append([]) + num_frames_in_batch = length + self.minibatch[-1].append(data[i]) + # Static batch size + else: + cur = 0 + while cur < num_data: + end = min(cur + batch_size, num_data) + item = [] + for i in range(cur, end): + item.append(data[i]) + self.minibatch.append(item) + cur = end + + def __len__(self): + return len(self.minibatch) + + def __getitem__(self, idx): + instance = self.minibatch[idx] + return instance["utt"], instance["feat"], instance["text"] diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 9ff95f29..8b1adcd0 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -247,11 +247,11 @@ class Trainer(): report("Rank", dist.get_rank()) report("epoch", self.epoch) report('step', self.iteration) - report('iter', batch_index + 1) - report('total',len(self.train_loader)) report("lr", self.lr_scheduler()) self.train_batch(batch_index, batch, msg) self.after_train_batch() + report('iter', batch_index + 1) + report('total', len(self.train_loader)) report('reader_cost', dataload_time) observation['batch_cost'] = observation[ 'reader_cost'] + observation['step_cost']