From 8c4f60be097175daf52e3888e1c4059e09368b13 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 9 Feb 2021 09:54:28 +0000 Subject: [PATCH] fix ctc cuda memcopy error --- examples/tiny/local/run_train.sh | 1 + model_utils/model.py | 4 ++-- model_utils/network.py | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh index 7880a4bba..b197a9fd1 100644 --- a/examples/tiny/local/run_train.sh +++ b/examples/tiny/local/run_train.sh @@ -17,6 +17,7 @@ export FLAGS_sync_nccl_allreduce=0 #CUDA_VISIBLE_DEVICES=0,1,2,3 \ CUDA_VISIBLE_DEVICES=1,2,3 \ python3 -u ${MAIN_ROOT}/train.py \ +--device 'gpu' \ --nproc 1 \ --config conf/deepspeech2.yaml \ --output ckpt diff --git a/model_utils/model.py b/model_utils/model.py index d845eb3c2..2ef135d2a 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -149,8 +149,8 @@ class DeepSpeech2Trainer(Trainer): self.logger.info("Setup model/optimizer/criterion!") def compute_losses(self, inputs, outputs): - _, texts, logits_len, texts_len = inputs - logits = outputs + _, texts, _, texts_len = inputs + logits, logits_len = outputs loss = self.criterion(logits, texts, logits_len, texts_len) return loss diff --git a/model_utils/network.py b/model_utils/network.py index 83b91fb70..68bf4e7b3 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -517,11 +517,11 @@ class DeepSpeech2(nn.Layer): #ctcdecoder need probs, not log_probs probs = F.softmax(logits) - return logits, probs + return logits, probs, audio_len @paddle.no_grad() def infer(self, audio, audio_len): - _, probs = self.predict(audio, audio_len) + _, probs, audio_len = self.predict(audio, audio_len) return probs def forward(self, audio, text, audio_len, text_len): @@ -531,8 +531,8 @@ class DeepSpeech2(nn.Layer): audio_len: shape [B] text_len: shape [B] """ - logits, _ = self.predict(audio, audio_len) - return logits + logits, _, audio_len = self.predict(audio, audio_len) + return logits, audio_len class DeepSpeech2Loss(nn.Layer):