From 1e37e2cca36d8e8767c3020bb8cb586b2aa5839c Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 20 Apr 2021 07:15:25 +0000
Subject: [PATCH] Support multi-GPU validation

---
 deepspeech/exps/deepspeech2/model.py |   8 +--
 deepspeech/exps/u2/model.py          | 100 +++++++++++++++------
 deepspeech/training/trainer.py       |  20 +++++-
 examples/tiny/s1/local/train.sh      |   8 ++-
 4 files changed, 83 insertions(+), 53 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index 5efe4b48d..e21a03f6f 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -67,7 +67,6 @@ class DeepSpeech2Trainer(Trainer):
                            self.iteration)
         self.iteration += 1
 
-    @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
         logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
@@ -84,11 +83,10 @@ class DeepSpeech2Trainer(Trainer):
                 valid_losses['val_loss'].append(float(loss))
 
             if (i + 1) % self.config.training.log_interval == 0:
-                valid_losses['val_history_loss'] = total_loss / num_seen_utts
-
-                # write visual log
                 valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
+                valid_losses['val_history_loss'] = total_loss / num_seen_utts
+
                 # logging
                 msg = f"Valid: Rank: {dist.get_rank()}, "
                 msg += "epoch: {}, ".format(self.epoch)
@@ -103,6 +101,8 @@ class DeepSpeech2Trainer(Trainer):
                     self.visualizer.add_scalar("valid/{}".format(k), v,
                                                self.iteration)
 
+        logger.info('Rank {} Val info val_loss {}'.format(
+            dist.get_rank(), total_loss / num_seen_utts))
         return total_loss, num_seen_utts
 
     def setup_model(self):
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 142491f86..58076e4be 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -109,6 +109,43 @@ class U2Trainer(Trainer):
                 self.visualizer.add_scalars("step", losses_np_v,
                                             self.iteration - 1)
 
+    @paddle.no_grad()
+    def valid(self):
+        self.model.eval()
+        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+        valid_losses = defaultdict(list)
+        num_seen_utts = 1
+        total_loss = 0.0
+        for i, batch in enumerate(self.valid_loader):
+            loss, attention_loss, ctc_loss = self.model(*batch)
+            if paddle.isfinite(loss):
+                num_utts = batch[0].shape[0]
+                num_seen_utts += num_utts
+                total_loss += float(loss) * num_utts
+                valid_losses['val_loss'].append(float(loss))
+                if attention_loss:
+                    valid_losses['val_att_loss'].append(float(attention_loss))
+                if ctc_loss:
+                    valid_losses['val_ctc_loss'].append(float(ctc_loss))
+
+            if (i + 1) % self.config.training.log_interval == 0:
+                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
+
+                valid_dump['val_history_loss'] = total_loss / num_seen_utts
+
+                # logging
+                msg = f"Valid: Rank: {dist.get_rank()}, "
+                msg += "epoch: {}, ".format(self.epoch)
+                msg += "step: {}, ".format(self.iteration)
+                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                                 for k, v in valid_dump.items())
+                logger.info(msg)
+
+        logger.info('Rank {} Val info val_loss {}'.format(
+            dist.get_rank(), total_loss / num_seen_utts))
+        return total_loss, num_seen_utts
+
     def train(self):
         """The training process control by step."""
         # !!!IMPORTANT!!!
@@ -148,53 +185,26 @@
                 raise e
 
             total_loss, num_seen_utts = self.valid()
-            self.save(
-                tag=self.epoch, infos={'val_loss': total_loss / num_seen_utts})
+            if dist.get_world_size() > 1:
+                num_seen_utts = paddle.to_tensor(num_seen_utts)
+                # The default reduce op in all_reduce is SUM.
+                dist.all_reduce(num_seen_utts)
+                total_loss = paddle.to_tensor(total_loss)
+                dist.all_reduce(total_loss)
+                cv_loss = total_loss / num_seen_utts
+                cv_loss = float(cv_loss)
+            else:
+                cv_loss = total_loss / num_seen_utts
+
+            logger.info(
+                'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
+            if self.visualizer:
+                self.visualizer.add_scalars(
+                    'epoch', {'cv_loss': cv_loss,
+                              'lr': self.lr_scheduler()}, self.epoch)
+            self.save(tag=self.epoch, infos={'val_loss': cv_loss})
             self.new_epoch()
 
-    @mp_tools.rank_zero_only
-    @paddle.no_grad()
-    def valid(self):
-        self.model.eval()
-        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
-        valid_losses = defaultdict(list)
-        num_seen_utts = 1
-        total_loss = 0.0
-        for i, batch in enumerate(self.valid_loader):
-            loss, attention_loss, ctc_loss = self.model(*batch)
-            if paddle.isfinite(loss):
-                num_utts = batch[0].shape[0]
-                num_seen_utts += num_utts
-                total_loss += float(loss) * num_utts
-                valid_losses = {'val_loss': float(loss)}
-                if attention_loss:
-                    valid_losses['val_att_loss'] = float(attention_loss)
-                if ctc_loss:
-                    valid_losses['val_ctc_loss'] = float(ctc_loss)
-
-            if (i + 1) % self.config.training.log_interval == 0:
-                valid_losses['val_history_loss'] = total_loss / num_seen_utts
-
-                # write visual log
-                valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
-
-                # logging
-                msg = f"Valid: Rank: {dist.get_rank()}, "
-                msg += "epoch: {}, ".format(self.epoch)
-                msg += "step: {}, ".format(self.iteration)
-                msg += "batch : {}/{}, ".format(i + 1, len(self.valid_loader))
-                msg += ', '.join('{}: {:>.6f}'.format(k, v)
-                                 for k, v in valid_losses.items())
-                logger.info(msg)
-
-        if self.visualizer:
-            valid_losses_v = valid_losses.copy()
-            valid_losses_v.update({"lr": self.lr_scheduler()})
-            self.visualizer.add_scalars('epoch', valid_losses_v,
-                                        self.epoch)
-
-        return total_loss, num_seen_utts
-
     def setup_dataloader(self):
         config = self.config.clone()
         config.defrost()
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index aded34624..128432aa9 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -202,7 +202,25 @@ class Trainer():
                 raise e
 
             total_loss, num_seen_utts = self.valid()
-            self.save(infos={'val_loss': total_loss / num_seen_utts})
+            if dist.get_world_size() > 1:
+                num_seen_utts = paddle.to_tensor(num_seen_utts)
+                # The default reduce op in all_reduce is SUM.
+                dist.all_reduce(num_seen_utts)
+                total_loss = paddle.to_tensor(total_loss)
+                dist.all_reduce(total_loss)
+                cv_loss = total_loss / num_seen_utts
+                cv_loss = float(cv_loss)
+            else:
+                cv_loss = total_loss / num_seen_utts
+
+            logger.info(
+                'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
+            if self.visualizer:
+                self.visualizer.add_scalars(
+                    'epoch', {'cv_loss': cv_loss,
+                              'lr': self.lr_scheduler()}, self.epoch)
+
+            self.save(infos={'val_loss': cv_loss})
             self.lr_scheduler.step()
             self.new_epoch()
 
diff --git a/examples/tiny/s1/local/train.sh b/examples/tiny/s1/local/train.sh
index 6511614a9..a0598e17a 100644
--- a/examples/tiny/s1/local/train.sh
+++ b/examples/tiny/s1/local/train.sh
@@ -1,11 +1,13 @@
 #! /usr/bin/env bash
 
-CUDA_VISIBLE_DEVICES=0 \
+ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
+echo "using $ngpu gpus..."
+
 python3 -u ${BIN_DIR}/train.py \
 --device 'gpu' \
---nproc 1 \
+--nproc ${ngpu} \
 --config conf/conformer.yaml \
---output ckpt
+--output ckpt-${1}
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"
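
Note on the aggregation pattern this patch introduces: each rank sums
loss * num_utts and the utterance count locally, then combines both with
paddle.distributed.all_reduce (whose default reduce op is SUM) before
dividing. Below is a minimal, self-contained sketch of the same pattern;
the helper name aggregate_cv_loss is illustrative only, and it assumes the
process group has already been initialized (e.g. via
paddle.distributed.spawn or init_parallel_env()):

    import paddle
    import paddle.distributed as dist

    def aggregate_cv_loss(total_loss: float, num_seen_utts: int) -> float:
        """Global average validation loss over all ranks' utterances."""
        if dist.get_world_size() > 1:
            # Cast both sums to float32 so the division is dtype-safe.
            # all_reduce defaults to SUM, so after these calls every
            # rank holds the totals accumulated across all ranks.
            utts = paddle.to_tensor(num_seen_utts, dtype='float32')
            dist.all_reduce(utts)
            loss_sum = paddle.to_tensor(total_loss, dtype='float32')
            dist.all_reduce(loss_sum)
            return float(loss_sum / utts)
        return total_loss / num_seen_utts

Summing counts and loss totals separately before dividing keeps the global
average exact even when ranks see different numbers of utterances;
all-reducing the per-rank mean losses instead would weight every rank
equally and bias the result.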