diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 658c5ba54..4a01b7298 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -237,9 +237,9 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
-            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
-            batch_indices.extend(indices[-res_len:])
+            if res_len != 0:
+                batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
         assert len(indices) == len(
             batch_indices
@@ -381,9 +381,9 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert (clipped == False)
         if not clipped:
             res_len = len(indices) - shift_len - len(batch_indices)
-            assert res_len != 0, f"_batch_shuffle clipped {len(indices)} , {shift_len}, {len(batch_indices)}"
             # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
-            batch_indices.extend(indices[-res_len:])
+            if res_len != 0:
+                batch_indices.extend(indices[-res_len:])
             batch_indices.extend(indices[0:shift_len])
         assert len(indices) == len(
             batch_indices
@@ -532,108 +532,4 @@ class SpeechCollator():
         audio_lens = np.array(audio_lens).astype('int64')
         texts = np.array(texts).astype('int32')
         text_lens = np.array(text_lens).astype('int64')
-        return padded_audios, texts, audio_lens, text_lens
-
-
-# def create_dataloader(manifest_path,
-#                       vocab_filepath,
-#                       mean_std_filepath,
-#                       augmentation_config='{}',
-#                       max_duration=float('inf'),
-#                       min_duration=0.0,
-#                       stride_ms=10.0,
-#                       window_ms=20.0,
-#                       max_freq=None,
-#                       specgram_type='linear',
-#                       use_dB_normalization=True,
-#                       random_seed=0,
-#                       keep_transcription_text=False,
-#                       is_training=False,
-#                       batch_size=1,
-#                       num_workers=0,
-#                       sortagrad=False,
-#                       shuffle_method=None,
-#                       dist=False):
-
-#     dataset = DeepSpeech2Dataset(
-#         manifest_path,
-#         vocab_filepath,
-#         mean_std_filepath,
-#         augmentation_config=augmentation_config,
-#         max_duration=max_duration,
-#         min_duration=min_duration,
-#         stride_ms=stride_ms,
-#         window_ms=window_ms,
-#         max_freq=max_freq,
-#         specgram_type=specgram_type,
-#         use_dB_normalization=use_dB_normalization,
-#         random_seed=random_seed,
-#         keep_transcription_text=keep_transcription_text)

-#     if dist:
-#         batch_sampler = DeepSpeech2DistributedBatchSampler(
-#             dataset,
-#             batch_size,
-#             num_replicas=None,
-#             rank=None,
-#             shuffle=is_training,
-#             drop_last=is_training,
-#             sortagrad=is_training,
-#             shuffle_method=shuffle_method)
-#     else:
-#         batch_sampler = DeepSpeech2BatchSampler(
-#             dataset,
-#             shuffle=is_training,
-#             batch_size=batch_size,
-#             drop_last=is_training,
-#             sortagrad=is_training,
-#             shuffle_method=shuffle_method)
-
-#     def padding_batch(batch, padding_to=-1, flatten=False, is_training=True):
-#         """
-#         Padding audio features with zeros to make them have the same shape (or
-#         a user-defined shape) within one bach.
-
-#         If ``padding_to`` is -1, the maximun shape in the batch will be used
-#         as the target shape for padding. Otherwise, `padding_to` will be the
-#         target shape (only refers to the second axis).
-
-#         If `flatten` is True, features will be flatten to 1darray.
-#         """
-#         new_batch = []
-#         # get target shape
-#         max_length = max([audio.shape[1] for audio, text in batch])
-#         if padding_to != -1:
-#             if padding_to < max_length:
-#                 raise ValueError("If padding_to is not -1, it should be larger "
-#                                  "than any instance's shape in the batch")
-#             max_length = padding_to
-#         max_text_length = max([len(text) for audio, text in batch])
-#         # padding
-#         padded_audios = []
-#         audio_lens = []
-#         texts, text_lens = [], []
-#         for audio, text in batch:
-#             padded_audio = np.zeros([audio.shape[0], max_length])
-#             padded_audio[:, :audio.shape[1]] = audio
-#             if flatten:
-#                 padded_audio = padded_audio.flatten()
-#             padded_audios.append(padded_audio)
-#             audio_lens.append(audio.shape[1])
-#             padded_text = np.zeros([max_text_length])
-#             padded_text[:len(text)] = text
-#             texts.append(padded_text)
-#             text_lens.append(len(text))
-
-#         padded_audios = np.array(padded_audios).astype('float32')
-#         audio_lens = np.array(audio_lens).astype('int64')
-#         texts = np.array(texts).astype('int32')
-#         text_lens = np.array(text_lens).astype('int64')
-#         return padded_audios, texts, audio_lens, text_lens
-
-#     loader = DataLoader(
-#         dataset,
-#         batch_sampler=batch_sampler,
-#         collate_fn=partial(padding_batch, is_training=is_training),
-#         num_workers=num_workers, )
-#     return loader
+        return padded_audios, texts, audio_lens, text_lens
\ No newline at end of file
diff --git a/examples/aishell/conf/deepspeech2.yaml b/examples/aishell/conf/deepspeech2.yaml
index c85cf3480..552d114c5 100644
--- a/examples/aishell/conf/deepspeech2.yaml
+++ b/examples/aishell/conf/deepspeech2.yaml
@@ -39,13 +39,13 @@ training:
   valid_interval: 1000
 decoding:
   batch_size: 128
-  error_rate_type: wer
+  error_rate_type: cer
   decoding_method: ctc_beam_search
-  lang_model_path: models/lm/common_crawl_00.prune01111.trie.klm
-  alpha: 2.5
-  beta: 0.3
-  beam_size: 500
-  cutoff_prob: 1.0
+  lang_model_path: models/lm/zh_giga.no_cna_cmn.prune01244.klm
+  alpha: 2.6
+  beta: 5.0
+  beam_size: 300
+  cutoff_prob: 0.99
   cutoff_top_n: 40
   num_proc_bsearch: 8
 
diff --git a/examples/aishell/local/run_test.sh b/examples/aishell/local/run_test.sh
index d2dbfb4f0..1015799b5 100644
--- a/examples/aishell/local/run_test.sh
+++ b/examples/aishell/local/run_test.sh
@@ -9,30 +9,12 @@ fi
 cd - > /dev/null
 
 
-# evaluate model
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+CUDA_VISIBLE_DEVICES=6 \
 python3 -u ${MAIN_ROOT}/test.py \
---batch_size=128 \
---beam_size=300 \
---num_proc_bsearch=8 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=1024 \
---alpha=2.6 \
---beta=5.0 \
---cutoff_prob=0.99 \
---cutoff_top_n=40 \
---use_gru=True \
---use_gpu=True \
---share_rnn_weights=False \
---test_manifest="data/manifest.test" \
---mean_std_path="data/mean_std.npz" \
---vocab_path="data/vocab.txt" \
---model_path="checkpoints/step_final" \
---lang_model_path="${MAIN_ROOT}/models/lm/zh_giga.no_cna_cmn.prune01244.klm" \
---decoding_method="ctc_beam_search" \
---error_rate_type="cer" \
---specgram_type="linear"
+--device 'gpu' \
+--nproc 1 \
+--config conf/deepspeech2.yaml \
+--output ckpt
 
 if [ $? -ne 0 ]; then
     echo "Failed in evaluation!"
diff --git a/examples/aishell/local/run_train.sh b/examples/aishell/local/run_train.sh
index e3e8c745e..507562a29 100644
--- a/examples/aishell/local/run_train.sh
+++ b/examples/aishell/local/run_train.sh
@@ -32,7 +32,7 @@ export FLAGS_sync_nccl_allreduce=0
 #--specgram_type="linear" \
 #--shuffle_method="batch_shuffle_clipped" \
 
-CUDA_VISIBLE_DEVICES=1,2,6,7 \
+CUDA_VISIBLE_DEVICES=2,3,5,7 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
 --nproc 4 \
diff --git a/examples/tiny/conf/deepspeech2.yaml b/examples/tiny/conf/deepspeech2.yaml
index 4aa8b8e90..457a56b2e 100644
--- a/examples/tiny/conf/deepspeech2.yaml
+++ b/examples/tiny/conf/deepspeech2.yaml
@@ -25,8 +25,8 @@ data:
 model:
   num_conv_layers: 2
   num_rnn_layers: 3
-  rnn_layer_size: 2048
-  use_gru: False
+  rnn_layer_size: 2048
+  use_gru: True
   share_rnn_weights: True
 training:
   n_epoch: 20
diff --git a/examples/tiny/local/run_test.sh b/examples/tiny/local/run_test.sh
index 2df608895..cfedd1ca8 100644
--- a/examples/tiny/local/run_test.sh
+++ b/examples/tiny/local/run_test.sh
@@ -8,7 +8,7 @@ if [ $? -ne 0 ]; then
 fi
 cd - > /dev/null
 
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${MAIN_ROOT}/test.py \
 --device 'gpu' \
 --nproc 1 \
diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh
index 7037c07e3..9c81e49b5 100644
--- a/examples/tiny/local/run_train.sh
+++ b/examples/tiny/local/run_train.sh
@@ -3,8 +3,7 @@
 export FLAGS_sync_nccl_allreduce=0
 
 #CUDA_VISIBLE_DEVICES=0,1,2,3 \
-#CUDA_VISIBLE_DEVICES=0,4,5,6 \
-CUDA_VISIBLE_DEVICES=0 \
+CUDA_VISIBLE_DEVICES=0,1 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
 --nproc 1 \
diff --git a/model_utils/model.py b/model_utils/model.py
index a48307863..f115028d0 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -68,32 +68,18 @@ class DeepSpeech2Trainer(Trainer):
         loss = self.criterion(logits, texts, logits_len, texts_len)
         return loss
 
-    def read_batch(self):
-        """Read a batch from the train_loader.
-        Returns
-        -------
-        List[Tensor]
-            A batch.
-        """
-        try:
-            batch = next(self.iterator)
-        except StopIteration as e:
-            raise e
-        return batch
-
-    def train_batch(self):
+    def train_batch(self, batch_data):
         start = time.time()
-        batch = self.read_batch()
-        data_loader_time = time.time() - start
-
-        self.optimizer.clear_grad()
         self.model.train()
-        audio, text, audio_len, text_len = batch
-        batch_size = audio.shape[0]
+
+        audio, text, audio_len, text_len = batch_data
         outputs = self.model(audio, text, audio_len, text_len)
-        loss = self.compute_losses(batch, outputs)
+        loss = self.compute_losses(batch_data, outputs)
+
         loss.backward()
         self.optimizer.step()
+        self.optimizer.clear_grad()
+
         iteration_time = time.time() - start
 
         losses_np = {
@@ -104,13 +90,9 @@ class DeepSpeech2Trainer(Trainer):
         msg = "Train: Rank: {}, ".format(dist.get_rank())
         msg += "epoch: {}, ".format(self.epoch)
         msg += "step: {}, ".format(self.iteration)
-        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
-                                                  iteration_time)
-        msg += f"batch size: {batch_size}, "
+        msg += "time: {:>.3f}s, ".format(iteration_time)
         msg += ', '.join('{}: {:>.6f}'.format(k, v)
                          for k, v in losses_np.items())
-
-        #if self.iteration % 100 == 0:
         self.logger.info(msg)
 
         if dist.get_rank() == 0 and self.visualizer:
@@ -118,6 +100,14 @@ class DeepSpeech2Trainer(Trainer):
                 self.visualizer.add_scalar("train/{}".format(k), v,
                                            self.iteration)
 
+    def new_epoch(self):
+        """Reset the train loader and increment ``epoch``.
+        """
+        if self.parallel:
+            # batch sampler epoch start from 0
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.epoch += 1
+
     def train(self):
         """The training process.
 
@@ -126,21 +116,20 @@ class DeepSpeech2Trainer(Trainer):
         """
         self.new_epoch()
         while self.epoch <= self.config.training.n_epoch:
-            try:
+            for batch in self.train_loader:
                 self.iteration += 1
-                self.train_batch()
+                self.train_batch(batch)
 
                 # if self.iteration % self.config.training.valid_interval == 0:
                 #     self.valid()
 
                 # if self.iteration % self.config.training.save_interval == 0:
                 #     self.save()
-            except StopIteration:
-                self.iteration -= 1  #epoch end, iteration ahead 1
-                self.valid()
-                self.save()
-                self.lr_scheduler.step()
-                self.new_epoch()
+
+            self.valid()
+            self.save()
+            self.lr_scheduler.step()
+            self.new_epoch()
 
     def compute_metrics(self, inputs, outputs):
         pass
@@ -152,14 +141,13 @@ class DeepSpeech2Trainer(Trainer):
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
             audio, text, audio_len, text_len = batch
-            batch_size = audio.shape[0]
             outputs = self.model(audio, text, audio_len, text_len)
             loss = self.compute_losses(batch, outputs)
             metrics = self.compute_metrics(batch, outputs)
 
             valid_losses['val_loss'].append(float(loss))
             valid_losses['val_loss_div_batchsize'].append(
-                float(loss) / batch_size)
+                float(loss) / self.config.data.batch_size)
 
         # write visual log
         valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}