fix egs bugs (#552)

* fix egs

* fix log
pull/570/head
Hui Zhang 4 years ago committed by GitHub
parent 4c8c2178af
commit 258307df9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -63,8 +63,6 @@ class DeepSpeech2Trainer(Trainer):
losses_np = { losses_np = {
'train_loss': float(loss), 'train_loss': float(loss),
'train_loss_div_batchsize':
float(loss) / self.config.data.batch_size
} }
msg = "Train: Rank: {}, ".format(dist.get_rank()) msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch) msg += "epoch: {}, ".format(self.epoch)
@ -90,8 +88,6 @@ class DeepSpeech2Trainer(Trainer):
loss = self.model(*batch) loss = self.model(*batch)
valid_losses['val_loss'].append(float(loss)) valid_losses['val_loss'].append(float(loss))
valid_losses['val_loss_div_batchsize'].append(
float(loss) / self.config.data.batch_size)
# write visual log # write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

@ -62,14 +62,15 @@ class CTCLoss(nn.Layer):
"""Compute CTC loss. """Compute CTC loss.
Args: Args:
logits ([paddle.Tensor]): [description] logits ([paddle.Tensor]): [B, Tmax, D]
ys_pad ([paddle.Tensor]): [description] ys_pad ([paddle.Tensor]): [B, Tmax]
hlens ([paddle.Tensor]): [description] hlens ([paddle.Tensor]): [B]
ys_lens ([paddle.Tensor]): [description] ys_lens ([paddle.Tensor]): [B]
Returns: Returns:
[paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}. [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
""" """
B = paddle.shape(logits)[0]
# warp-ctc need logits, and do softmax on logits by itself # warp-ctc need logits, and do softmax on logits by itself
# warp-ctc need activation with shape [T, B, V + 1] # warp-ctc need activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D) # logits: (B, L, D) -> (L, B, D)
@ -78,5 +79,5 @@ class CTCLoss(nn.Layer):
# wenet do batch-size average, deepspeech2 not do this # wenet do batch-size average, deepspeech2 not do this
# Batch-size average # Batch-size average
# loss = loss / paddle.shape(logits)[1] # loss = loss / B
return loss return loss

@ -2,3 +2,4 @@ data
ckpt* ckpt*
demo_cache demo_cache
*.log *.log
log

@ -1,6 +1,6 @@
#! /usr/bin/env bash #! /usr/bin/env bash
if [[ $# != 1 ]]; if [[ $# != 1 ]]; then
echo "usage: $0 ckpt-path" echo "usage: $0 ckpt-path"
exit -1 exit -1
fi fi

@ -7,7 +7,7 @@ source path.sh
bash ./local/data.sh bash ./local/data.sh
# train model # train model
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh baseline
# test model # test model
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh

@ -1,6 +1,6 @@
#! /usr/bin/env bash #! /usr/bin/env bash
if [[ $# != 1 ]]; if [[ $# != 1 ]];then
echo "usage: $0 ckpt-path" echo "usage: $0 ckpt-path"
exit -1 exit -1
fi fi

@ -11,7 +11,7 @@ python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \ --device 'gpu' \
--nproc ${ngpu} \ --nproc ${ngpu} \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \
--output ckpt --output ckpt-${1}
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in training!" echo "Failed in training!"

@ -1,6 +1,6 @@
#! /usr/bin/env bash #! /usr/bin/env bash
if [[ $# != 1 ]]; if [[ $# != 1 ]];then
echo "usage: $0 ckpt-path" echo "usage: $0 ckpt-path"
exit -1 exit -1
fi fi

@ -6,7 +6,6 @@ if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/test.py \ python3 -u ${BIN_DIR}/test.py \
--device 'gpu' \ --device 'gpu' \
--nproc 1 \ --nproc 1 \

@ -2,7 +2,6 @@
export FLAGS_sync_nccl_allreduce=0 export FLAGS_sync_nccl_allreduce=0
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/train.py \ python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \ --device 'gpu' \
--nproc 1 \ --nproc 1 \

Loading…
Cancel
Save