From 466672e1de6c77c1b18109332d1e1b72f341fa6a Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Sat, 2 Oct 2021 12:44:27 +0000
Subject: [PATCH] no_sync if paddle support else nullcontext

---
 deepspeech/exps/deepspeech2/model.py | 3 ++-
 deepspeech/exps/u2/model.py          | 3 ++-
 deepspeech/exps/u2_kaldi/model.py    | 3 ++-
 deepspeech/exps/u2_st/model.py       | 3 ++-
 4 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py
index e84de615..3dc8286d 100644
--- a/deepspeech/exps/deepspeech2/model.py
+++ b/deepspeech/exps/deepspeech2/model.py
@@ -87,7 +87,8 @@ class DeepSpeech2Trainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.
diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 9cb3fa3c..65ec5174 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -106,7 +106,8 @@ class U2Trainer(Trainer):
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
             # When using cpu w/o DDP, model does not have `no_sync`
-            context = self.model.no_sync if self.parallel else nullcontext
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index d38afe25..5a72e44d 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index c480499c..08060d97 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -110,7 +110,8 @@ class U2STTrainer(Trainer):
             # Disable gradient synchronizations across DDP processes.
             # Within this context, gradients will be accumulated on module
             # variables, which will later be synchronized.
-            context = self.model.no_sync
+            context = self.model.no_sync if (hasattr(self.model, "no_sync") and
+                                             self.parallel) else nullcontext
         else:
             # Used for single gpu training and DDP gradient synchronization
             # processes.
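
Note: the guard matters because `no_sync` only exists on the DDP-wrapped model, and
only on Paddle releases that provide it; on CPU or single-GPU runs the bare model has
no such method (as the u2 comment notes), so the previously unconditional
`self.model.no_sync` would raise AttributeError. Below is a minimal sketch of the
pattern this patch installs, outside the trainer classes for clarity; the names
`grad_sync_context`, `accum_steps`, `batches`, and `optimizer` are hypothetical
placeholders, not taken from the patch:

    from contextlib import nullcontext

    def grad_sync_context(model, parallel):
        """Pick the gradient-sync context the way the patch does: the wrapped
        model's no_sync when running under DDP and the runtime provides it,
        otherwise a no-op context."""
        if parallel and hasattr(model, "no_sync"):
            return model.no_sync  # suppresses per-step gradient all-reduce
        return nullcontext        # single GPU / CPU / older Paddle: no-op

    # Usage sketch: accumulate gradients over accum_steps micro-batches; only
    # the final backward() runs outside no_sync, so gradients are synchronized
    # across ranks once per effective batch instead of once per micro-batch.
    context = grad_sync_context(model, parallel=True)
    for step, batch in enumerate(batches):
        last = (step + 1) % accum_steps == 0
        with (nullcontext() if last else context()):
            loss = model(*batch)
            loss.backward()
        if last:
            optimizer.step()
            optimizer.clear_grad()

Note that `context` is stored uncalled (`model.no_sync`, not `model.no_sync()`),
matching the patch: both `no_sync` and `nullcontext` are invoked later at the
`with` site, so the two branches are interchangeable.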