Merge branch 'develop' into transformer

pull/556/head
Hui Zhang 5 years ago committed by GitHub
commit a0b6b00116

@@ -0,0 +1,42 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions).
If you've found a bug, please create an issue with the following information:
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Environment (please complete the following information):**
- OS: [e.g. Ubuntu]
- GCC/G++ Version [e.g. 8.3]
- Python Version [e.g. 3.7]
- PaddlePaddle Version [e.g. 2.0.0]
- Model Version [e.g. 2.0.0]
- GPU/DRIVER Information [e.g. Tesla V100-SXM2-32GB/440.64.00]
- CUDA/CUDNN Version [e.g. cuda-10.2]
- MKL Version
- TensorRT Version
**Additional context**
Add any other context about the problem here.

@@ -0,0 +1,24 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[Feature request]"
labels: feature request
assignees: ''
---
For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions).
If you have a feature request, please create an issue with the following information:
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

@@ -18,30 +18,16 @@
 * python>=3.7
 * paddlepaddle>=2.0.0
-- Run the setup script for the remaining dependencies
-```bash
-git clone https://github.com/PaddlePaddle/DeepSpeech.git
-cd DeepSpeech
-pushd tools; make; popd
-source tools/venv/bin/activate
-bash setup.sh
-```
-- Source venv before do experiment.
-```bash
-source tools/venv/bin/activate
-```
+Please see [install](docs/install.md).
 ## Getting Started
-Please see [Getting Started](docs/geting_started.md) and [tiny egs](examples/tiny/README.md).
+Please see [Getting Started](docs/getting_started.md) and [tiny egs](examples/tiny/README.md).
 ## More Information
 * [Install](docs/install.md)
-* [Getting Started](docs/geting_stared.md)
+* [Getting Started](docs/getting_started.md)
 * [Data Prepration](docs/data_preparation.md)
 * [Data Augmentation](docs/augmentation.md)
 * [Ngram LM](docs/ngram_lm.md)
@@ -53,7 +39,7 @@ Please see [Getting Started](docs/geting_started.md) and [tiny egs](examples/tiny/README.md).
 ## Questions and Help
-You are welcome to submit questions and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project.
+You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project.
 ## License

@@ -14,33 +14,20 @@
 * [Baidu's Deep Speech2](http://proceedings.mlr.press/v48/amodei16.pdf)
 ## 安装
 * python>=3.7
 * paddlepaddle>=2.0.0
-- 安装依赖
-```bash
-git clone https://github.com/PaddlePaddle/DeepSpeech.git
-cd DeepSpeech
-pushd tools; make; popd
-source tools/venv/bin/activate
-bash setup.sh
-```
-- 开始实验前要source环境.
-```bash
-source tools/venv/bin/activate
-```
+参看 [安装](docs/install.md)。
 ## 开始
-请查看 [Getting Started](docs/geting_started.md) 和 [tiny egs](examples/tiny/README.md)。
+请查看 [Getting Started](docs/getting_started.md) 和 [tiny egs](examples/tiny/README.md)。
 ## 更多信息
 * [安装](docs/install.md)
-* [开始](docs/geting_stared.md)
+* [开始](docs/getting_started.md)
 * [数据处理](docs/data_preparation.md)
 * [数据增强](docs/augmentation.md)
 * [语言模型](docs/ngram_lm.md)
@@ -51,7 +38,7 @@ source tools/venv/bin/activate
 ## 问题和帮助
-欢迎您在[Github问题](https://github.com/PaddlePaddle/models/issues)中提交问题和bug。也欢迎您为这个项目做出贡献。
+欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。
 ## License

@@ -39,7 +39,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.io.sampler import SortagradBatchSampler
 from deepspeech.io.dataset import ManifestDataset
-from deepspeech.modules.loss import CTCLoss
 from deepspeech.models.deepspeech2 import DeepSpeech2Model
 from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
@@ -63,8 +62,6 @@ class DeepSpeech2Trainer(Trainer):
         losses_np = {
             'train_loss': float(loss),
-            'train_loss_div_batchsize':
-            float(loss) / self.config.data.batch_size
         }
         msg = "Train: Rank: {}, ".format(dist.get_rank())
         msg += "epoch: {}, ".format(self.epoch)
@@ -90,8 +87,6 @@ class DeepSpeech2Trainer(Trainer):
             loss = self.model(*batch)
             valid_losses['val_loss'].append(float(loss))
-            valid_losses['val_loss_div_batchsize'].append(
-                float(loss) / self.config.data.batch_size)
         # write visual log
         valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

@@ -170,7 +170,8 @@ class DeepSpeech2Model(nn.Layer):
             odim=dict_size + 1,  # <blank> is append after vocab
             blank_id=dict_size,  # last token is <blank>
             dropout_rate=0.0,
-            reduction=True)
+            reduction=True,  # sum
+            batch_average=True)  # sum / batch_size
     def forward(self, audio, text, audio_len, text_len):
         """Compute Model loss

@@ -36,14 +36,16 @@ class CTCDecoder(nn.Layer):
                  odim,
                  blank_id=0,
                  dropout_rate: float=0.0,
-                 reduction: bool=True):
+                 reduction: bool=True,
+                 batch_average: bool=False):
         """CTC decoder
         Args:
             enc_n_units ([int]): encoder output dimention
             vocab_size ([int]): text vocabulary size
             dropout_rate (float): dropout rate (0.0 ~ 1.0)
-            reduction (bool): reduce the CTC loss into a scalar
+            reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
+            batch_average (bool): do batch dim wise average.
         """
         assert check_argument_types()
         super().__init__()
@@ -53,7 +55,10 @@ class CTCDecoder(nn.Layer):
         self.dropout_rate = dropout_rate
         self.ctc_lo = nn.Linear(enc_n_units, self.odim)
         reduction_type = "sum" if reduction else "none"
-        self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
+        self.criterion = CTCLoss(
+            blank=self.blank_id,
+            reduction=reduction_type,
+            batch_average=batch_average)
         # CTCDecoder LM Score handle
         self._ext_scorer = None
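As a usage sketch, a caller now threads `batch_average` through the decoder constructor. The import path and the sizes below are assumptions for illustration, not part of this commit:

```python
# Sketch only: assumes the repo is on PYTHONPATH and that CTCDecoder is importable
# from deepspeech.modules.ctc; the vocabulary and encoder sizes are placeholders.
from deepspeech.modules.ctc import CTCDecoder

vocab_size = 4233                  # placeholder vocabulary size
ctc_head = CTCDecoder(
    enc_n_units=1024,              # placeholder encoder output dimension
    odim=vocab_size + 1,           # <blank> appended after the vocabulary
    blank_id=vocab_size,           # last token id is <blank>
    dropout_rate=0.0,
    reduction=True,                # mapped to reduction_type="sum" internally
    batch_average=True)            # forwarded to CTCLoss(batch_average=...)
```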

@@ -25,32 +25,33 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
 class CTCLoss(nn.Layer):
-    def __init__(self, blank=0, reduction='sum'):
+    def __init__(self, blank=0, reduction='sum', batch_average=False):
         super().__init__()
         # last token id as blank id
         self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
+        self.batch_average = batch_average
     def forward(self, logits, ys_pad, hlens, ys_lens):
         """Compute CTC loss.
         Args:
-            logits ([paddle.Tensor]): [description]
-            ys_pad ([paddle.Tensor]): [description]
-            hlens ([paddle.Tensor]): [description]
-            ys_lens ([paddle.Tensor]): [description]
+            logits ([paddle.Tensor]): [B, Tmax, D]
+            ys_pad ([paddle.Tensor]): [B, Tmax]
+            hlens ([paddle.Tensor]): [B]
+            ys_lens ([paddle.Tensor]): [B]
         Returns:
             [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
         """
+        B = paddle.shape(logits)[0]
         # warp-ctc need logits, and do softmax on logits by itself
         # warp-ctc need activation with shape [T, B, V + 1]
         # logits: (B, L, D) -> (L, B, D)
         logits = logits.transpose([1, 0, 2])
         loss = self.loss(logits, ys_pad, hlens, ys_lens)
-        # Batch-size average
-        # loss = loss / paddle.shape(logits)[1]
+        if self.batch_average:
+            # wenet do batch-size average, deepspeech2 not do this
+            # Batch-size average
+            loss = loss / B
         return loss
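The control flow of the patched `forward` can be summarised without paddle. This is a framework-free sketch with the underlying CTC criterion stubbed out, not the actual implementation:

```python
# Framework-free sketch of the patched CTCLoss.forward reduction logic.
# The real class delegates to paddle.nn.CTCLoss; here the per-utterance losses
# are passed in directly so only the new batch_average branch is shown.
def ctc_forward_sketch(per_utt_losses, batch_average=False):
    B = len(per_utt_losses)        # batch size, like paddle.shape(logits)[0]
    loss = sum(per_utt_losses)     # 'sum' reduction of the underlying CTC loss
    if batch_average:
        loss = loss / B            # the new behaviour: sum / batch_size
    return loss

assert ctc_forward_sketch([10.0, 20.0, 30.0]) == 60.0
assert ctc_forward_sketch([10.0, 20.0, 30.0], batch_average=True) == 20.0
```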

@@ -1,6 +1,8 @@
 # Prepare Language Model
-A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. Users can simply run this to download the preprared language models:
+A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. The bash script to download the LM is each example's `local/download_lm_*.sh`.
+For example, users can simply run this to download the prepared Mandarin language model:
 ```bash
 cd examples/aishell
@@ -8,7 +10,9 @@ source path.sh
 bash local/download_lm_ch.sh
 ```
-If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. Here we provide some tips to show how we preparing our English and Mandarin language models. You can take it as a reference when you train your own.
+If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials.
+Here we provide some tips to show how we prepared our English and Mandarin language models.
+You can take them as a reference when you train your own.
 ## English LM
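One way to sanity-check a downloaded LM from Python is to load it with the `kenlm` pip package and score a sentence. This is only an illustration under that assumption; it is not how the repo's CTC beam-search scorer consumes the model:

```python
# Sketch: quick sanity check of a downloaded KenLM binary (assumes `pip install kenlm`).
# The path matches the English LM referenced in the configs; the sentence is arbitrary.
import kenlm

lm = kenlm.Model("data/lm/common_crawl_00.prune01111.trie.klm")
print(lm.score("the quick brown fox", bos=True, eos=True))  # log10 probability
print(lm.perplexity("the quick brown fox"))
```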

@@ -2,3 +2,4 @@ data
 ckpt*
 demo_cache
 *.log
+log

@@ -1,7 +1,7 @@
 # Aishell-1
 ## CTC
-| Model | Config | Test set | CER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.078977 |
-| DeepSpeech2 | release 1.8.5 | test | 0.080447 |
+| Model | Config | Test Set | CER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test | 0.077249 | 7.036566 |
+| DeepSpeech2 | release 1.8.5 | test | 0.087004 | 8.575452 |

@@ -29,8 +29,8 @@ model:
   use_gru: True
   share_rnn_weights: False
 training:
-  n_epoch: 30
-  lr: 5e-4
+  n_epoch: 50
+  lr: 2e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: cer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
-  alpha: 2.6
+  alpha: 1.9
   beta: 5.0
   beam_size: 300
   cutoff_prob: 0.99

@@ -1,6 +1,6 @@
 #! /usr/bin/env bash
-if [[ $# != 1 ]];
+if [[ $# != 1 ]]; then
     echo "usage: $0 ckpt-path"
     exit -1
 fi

@@ -2,7 +2,7 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."

@@ -7,7 +7,7 @@ source path.sh
 bash ./local/data.sh
 # train model
-CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh
+CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh baseline
 # test model
 CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh

@@ -1,7 +1,7 @@
 # LibriSpeech
 ## CTC
-| Model | Config | Test set | WER |
-| --- | --- | --- | --- |
-| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
-| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
+| Model | Config | Test Set | WER | Valid Loss |
+| --- | --- | --- | --- | --- |
+| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.069357 | 15.078561 |
+| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 | 15.351633 |

@@ -29,8 +29,8 @@ model:
   use_gru: False
   share_rnn_weights: True
 training:
-  n_epoch: 20
-  lr: 5e-4
+  n_epoch: 50
+  lr: 1e-3
   lr_decay: 0.83
   weight_decay: 1e-06
   global_grad_clip: 5.0
@@ -39,7 +39,7 @@ decoding:
   error_rate_type: wer
   decoding_method: ctc_beam_search
   lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
-  alpha: 2.5
+  alpha: 1.9
   beta: 0.3
   beam_size: 500
   cutoff_prob: 1.0

@@ -1,6 +1,6 @@
 #! /usr/bin/env bash
-if [[ $# != 1 ]];
+if [[ $# != 1 ]];then
     echo "usage: $0 ckpt-path"
     exit -1
 fi

@@ -1,8 +1,9 @@
 #! /usr/bin/env bash
-export FLAGS_sync_nccl_allreduce=0
+#export FLAGS_sync_nccl_allreduce=0
 # https://github.com/PaddlePaddle/Paddle/pull/28484
-export NCCL_SHM_DISABLE=1
+#export NCCL_SHM_DISABLE=1
 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
 echo "using $ngpu gpus..."
@@ -11,7 +12,7 @@ python3 -u ${BIN_DIR}/train.py \
     --device 'gpu' \
     --nproc ${ngpu} \
     --config conf/deepspeech2.yaml \
-    --output ckpt
+    --output ckpt-${1}
 if [ $? -ne 0 ]; then
     echo "Failed in training!"

@@ -1,6 +1,6 @@
 #! /usr/bin/env bash
-if [[ $# != 1 ]];
+if [[ $# != 1 ]];then
     echo "usage: $0 ckpt-path"
     exit -1
 fi

@@ -6,7 +6,6 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
-CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${BIN_DIR}/test.py \
     --device 'gpu' \
     --nproc 1 \

@@ -2,7 +2,6 @@
 export FLAGS_sync_nccl_allreduce=0
-CUDA_VISIBLE_DEVICES=0 \
 python3 -u ${BIN_DIR}/train.py \
     --device 'gpu' \
     --nproc 1 \

@@ -1,13 +1,19 @@
 #! /usr/bin/env bash
+source utils/log.sh
 SUDO='sudo'
 if [ $(id -u) -eq 0 ]; then
     SUDO=''
 fi
-if [ -e /etc/lsb-release ];then
+if [ -e /etc/lsb-release ]; then
     #${SUDO} apt-get update
     ${SUDO} apt-get install -y pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
+else
+    error_msg "Please using Ubuntu or install `pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev` by user."
+    exit -1
 fi
 # install python dependencies
@@ -15,17 +21,17 @@ if [ -f "requirements.txt" ]; then
     pip3 install -r requirements.txt
 fi
 if [ $? != 0 ]; then
-    echo "Install python dependencies failed !!!"
+    error_msg "Install python dependencies failed !!!"
     exit 1
 fi
 # install package libsndfile
 python3 -c "import soundfile"
 if [ $? != 0 ]; then
-    echo "Install package libsndfile into default system path."
+    info_msg "Install package libsndfile into default system path."
     wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz"
     if [ $? != 0 ]; then
-        echo "Download libsndfile-1.0.28.tar.gz failed !!!"
+        error_msg "Download libsndfile-1.0.28.tar.gz failed !!!"
        exit 1
     fi
     tar -zxvf libsndfile-1.0.28.tar.gz
@@ -43,6 +49,10 @@ if [ $? != 0 ]; then
     sh setup.sh
     cd - > /dev/null
 fi
+python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
+if [ $? != 0 ]; then
+    error_msg "Please check why decoder install error!"
+    exit -1
+fi
-echo "Install all dependencies successfully."
+info_msg "Install all dependencies successfully."

@@ -0,0 +1,11 @@
_HDR_FMT="%.23s %s[%s]: "
_ERR_MSG_FMT="ERROR: ${_HDR_FMT}%s\n"
_INFO_MSG_FMT="INFO: ${_HDR_FMT}%s\n"
error_msg() {
printf "$_ERR_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}"
}
info_msg() {
printf "$_INFO_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}"
}