fix ctc cuda memcopy error

pull/522/head
Hui Zhang 5 years ago
parent 4dc75c40c9
commit 8c4f60be09

@ -17,6 +17,7 @@ export FLAGS_sync_nccl_allreduce=0
#CUDA_VISIBLE_DEVICES=0,1,2,3 \ #CUDA_VISIBLE_DEVICES=0,1,2,3 \
CUDA_VISIBLE_DEVICES=1,2,3 \ CUDA_VISIBLE_DEVICES=1,2,3 \
python3 -u ${MAIN_ROOT}/train.py \ python3 -u ${MAIN_ROOT}/train.py \
--device 'gpu' \
--nproc 1 \ --nproc 1 \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \
--output ckpt --output ckpt

@ -149,8 +149,8 @@ class DeepSpeech2Trainer(Trainer):
self.logger.info("Setup model/optimizer/criterion!") self.logger.info("Setup model/optimizer/criterion!")
def compute_losses(self, inputs, outputs): def compute_losses(self, inputs, outputs):
_, texts, logits_len, texts_len = inputs _, texts, _, texts_len = inputs
logits = outputs logits, logits_len = outputs
loss = self.criterion(logits, texts, logits_len, texts_len) loss = self.criterion(logits, texts, logits_len, texts_len)
return loss return loss

@ -517,11 +517,11 @@ class DeepSpeech2(nn.Layer):
#ctcdecoder need probs, not log_probs #ctcdecoder need probs, not log_probs
probs = F.softmax(logits) probs = F.softmax(logits)
return logits, probs return logits, probs, audio_len
@paddle.no_grad() @paddle.no_grad()
def infer(self, audio, audio_len): def infer(self, audio, audio_len):
_, probs = self.predict(audio, audio_len) _, probs, audio_len = self.predict(audio, audio_len)
return probs return probs
def forward(self, audio, text, audio_len, text_len): def forward(self, audio, text, audio_len, text_len):
@ -531,8 +531,8 @@ class DeepSpeech2(nn.Layer):
audio_len: shape [B] audio_len: shape [B]
text_len: shape [B] text_len: shape [B]
""" """
logits, _ = self.predict(audio, audio_len) logits, _, audio_len = self.predict(audio, audio_len)
return logits return logits, audio_len
class DeepSpeech2Loss(nn.Layer): class DeepSpeech2Loss(nn.Layer):

Loading…
Cancel
Save