merge develop and fix train by step

pull/578/head
Hui Zhang 5 years ago
commit b5bbfc5e24

@ -0,0 +1,42 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions).
If you've found a bug then please create an issue with the following information:
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
** Environment (please complete the following information):**
- OS: [e.g. Ubuntu]
- GCC/G++ Version [e.g. 8.3]
- Python Version [e.g. 3.7]
- PaddlePaddle Version [e.g. 2.0.0]
- Model Version [e.g. 2.0.0]
- GPU/DRIVER Informationo [e.g. Tesla V100-SXM2-32GB/440.64.00]
- CUDA/CUDNN Version [e.g. cuda-10.2]
- MKL Version
- TensorRT Version
**Additional context**
Add any other context about the problem here.

@ -0,0 +1,24 @@
---
name: Feature request
about: Suggest an idea for this project
title: "[Feature request]"
labels: feature request
assignees: ''
---
For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions).
If you've found a feature request then please create an issue with the following information:
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

@ -11,6 +11,13 @@ abort(){
unittest(){
cd $1 > /dev/null
if [ -f "setup.sh" ]; then
bash setup.sh
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
fi
if [ $? != 0 ]; then
exit 1
fi
find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \
xargs -0 -I{} -n1 bash -c \
'python3 -m unittest discover -v -s {}'
@ -19,6 +26,15 @@ unittest(){
coverage(){
cd $1 > /dev/null
if [ -f "setup.sh" ]; then
bash setup.sh
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
fi
if [ $? != 0 ]; then
exit 1
fi
find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \
xargs -0 -I{} -n1 bash -c \
'python3 -m coverage run --branch {}'

@ -21,26 +21,13 @@
* python>=3.7
* paddlepaddle>=2.0.0
- Run the setup script for the remaining dependencies
```bash
git clone https://github.com/PaddlePaddle/DeepSpeech.git
cd DeepSpeech
pushd tools; make; popd
source tools/venv/bin/activate
bash setup.sh
```
- Source venv before do experiment.
```bash
source tools/venv/bin/activate
```
Please see [install](docs/install.md).
## Getting Started
Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples/tiny/README.md).
## More Information
* [Install](docs/src/install.md)
@ -56,7 +43,7 @@ Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples
## Questions and Help
You are welcome to submit questions and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project.
You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project.
## License

@ -18,24 +18,11 @@
## 安装
* python>=3.7
* paddlepaddle>=2.0.0
- 安装依赖
```bash
git clone https://github.com/PaddlePaddle/DeepSpeech.git
cd DeepSpeech
pushd tools; make; popd
source tools/venv/bin/activate
bash setup.sh
```
- 开始实验前要source环境.
```bash
source tools/venv/bin/activate
```
参看 [安装](docs/install.md)。
## 开始
@ -55,7 +42,7 @@ source tools/venv/bin/activate
## 问题和帮助
欢迎您在[Github问题](https://github.com/PaddlePaddle/models/issues)中提交问题和bug。也欢迎您为这个项目做出贡献。
欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。
## License

@ -58,8 +58,6 @@ class DeepSpeech2Trainer(Trainer):
losses_np = {
'train_loss': float(loss),
'train_loss_div_batchsize':
float(loss) / self.config.data.batch_size
}
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
@ -85,8 +83,6 @@ class DeepSpeech2Trainer(Trainer):
loss = self.model(*batch)
valid_losses['val_loss'].append(float(loss))
valid_losses['val_loss_div_batchsize'].append(
float(loss) / self.config.data.batch_size)
# write visual log
valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}
@ -265,7 +261,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.logger.info(msg)
def run_test(self):
self.resume_or_load()
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:

@ -80,12 +80,15 @@ class U2Trainer(Trainer):
self.model.train()
start = time.time()
loss, attention_loss, ctc_loss = self.model(*batch_data)
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
if self.iteration % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
iteration_time = time.time() - start
@ -102,11 +105,49 @@ class U2Trainer(Trainer):
if self.iteration % train_conf.log_interval == 0:
self.logger.info(msg)
# display
if dist.get_rank() == 0 and self.visualizer:
for k, v in losses_np.items():
self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration)
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# script_model = paddle.jit.to_static(self.model)
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
from_scratch = self.resume_or_scratch()
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch <= self.config.training.n_epoch:
try:
data_start_time = time.time()
for batch in self.train_loader:
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "lr: {}, ".foramt(self.lr_scheduler())
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.iteration += 1
self.train_batch(batch, msg)
data_start_time = time.time()
except Exception as e:
self.logger.error(e)
raise e
self.valid()
self.save()
self.new_epoch()
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
@ -365,7 +406,7 @@ class U2Tester(U2Trainer):
self.logger.info(msg)
def run_test(self):
self.resume_or_load()
self.resume_or_scratch()
try:
self.test()
except KeyboardInterrupt:

@ -159,7 +159,8 @@ class DeepSpeech2Model(nn.Layer):
enc_n_units=self.encoder.output_size,
blank_id=dict_size, # last token is <blank>
dropout_rate=0.0,
reduction=True)
reduction=True, # sum
batch_average=True) # sum / batch_size
def forward(self, audio, audio_len, text, text_len):
"""Compute Model loss

@ -834,7 +834,14 @@ class U2Model(U2BaseModel):
decoder = TransformerDecoder(vocab_size,
encoder.output_size(),
**configs['decoder_conf'])
ctc = CTCDecoder(vocab_size, encoder.output_size())
ctc = CTCDecoder(
odim=vocab_size,
enc_n_units=encoder.output_size(),
blank_id=0,
dropout_rate=0.0,
reduction=True, # sum
batch_average=True) # sum / batch_size
return vocab_size, encoder, decoder, ctc
@classmethod

@ -37,14 +37,16 @@ class CTCDecoder(nn.Layer):
enc_n_units,
blank_id=0,
dropout_rate: float=0.0,
reduction: bool=True):
reduction: bool=True,
batch_average: bool=True):
"""CTC decoder
Args:
odim ([int]): text vocabulary size
enc_n_units ([int]): encoder output dimention
dropout_rate (float): dropout rate (0.0 ~ 1.0)
reduction (bool): reduce the CTC loss into a scalar
reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none'
batch_average (bool): do batch dim wise average.
"""
assert check_argument_types()
super().__init__()
@ -54,7 +56,10 @@ class CTCDecoder(nn.Layer):
self.dropout_rate = dropout_rate
self.ctc_lo = nn.Linear(enc_n_units, self.odim)
reduction_type = "sum" if reduction else "none"
self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
self.criterion = CTCLoss(
blank=self.blank_id,
reduction=reduction_type,
batch_average=batch_average)
# CTCDecoder LM Score handle
self._ext_scorer = None

@ -24,32 +24,33 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"]
class CTCLoss(nn.Layer):
def __init__(self, blank=0, reduction='sum'):
def __init__(self, blank=0, reduction='sum', batch_average=False):
super().__init__()
# last token id as blank id
self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
self.batch_average = batch_average
def forward(self, logits, ys_pad, hlens, ys_lens):
"""Compute CTC loss.
Args:
logits ([paddle.Tensor]): [description]
ys_pad ([paddle.Tensor]): [description]
hlens ([paddle.Tensor]): [description]
ys_lens ([paddle.Tensor]): [description]
logits ([paddle.Tensor]): [B, Tmax, D]
ys_pad ([paddle.Tensor]): [B, Tmax]
hlens ([paddle.Tensor]): [B]
ys_lens ([paddle.Tensor]): [B]
Returns:
[paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
"""
B = paddle.shape(logits)[0]
# warp-ctc need logits, and do softmax on logits by itself
# warp-ctc need activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])
loss = self.loss(logits, ys_pad, hlens, ys_lens)
# wenet do batch-size average, deepspeech2 not do this
# Batch-size average
# loss = loss / paddle.shape(logits)[1]
if self.batch_average:
# Batch-size average
loss = loss / B
return loss

@ -54,4 +54,4 @@ class WarmupLR(LRScheduler):
step_num**-0.5, step_num * self.warmup_steps**-1.5)
def set_step(self, step: int):
self.last_epoch = step
self.step(step)

@ -139,7 +139,7 @@ class Trainer():
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer, infos)
def resume_or_load(self):
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
@ -152,8 +152,20 @@ class Trainer():
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
if infos:
# restore from ckpt
self.iteration = infos["step"]
self.epoch = infos["epoch"]
self.lr_scheduler.step(self.iteration)
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
return False
else:
# from scratch, epoch and iteration init with zero
# save init model, i.e. 0 epoch
self.save()
# self.epoch start from 1.
self.new_epoch()
return True
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
@ -166,22 +178,22 @@ class Trainer():
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
from_scratch = self.resume_or_scratch()
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
self.new_epoch()
while self.epoch <= self.config.training.n_epoch:
try:
data_start_time = time.time()
for batch in self.train_loader:
dataload_time = time.time() - data_start_time
# iteration start from 1.
self.iteration += 1
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.iteration += 1
self.train_batch(batch, msg)
data_start_time = time.time()
except Exception as e:
@ -190,6 +202,7 @@ class Trainer():
self.valid()
self.save()
# lr control by epoch
self.lr_scheduler.step()
self.new_epoch()
@ -197,7 +210,6 @@ class Trainer():
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.resume_or_load()
try:
self.train()
except KeyboardInterrupt:
@ -298,7 +310,7 @@ class Trainer():
# global logger
stdout = False
save_path = log_file
save_path = str(log_file)
logging.basicConfig(
level=logging.DEBUG if stdout else logging.INFO,
format=format,

@ -45,7 +45,7 @@ source tools/venv/bin/activate
## Running in Docker Container (optional)
Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed, including the pre-built PaddlePaddle, CTC decoders, and other necessary Python and third-party packages. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed.
Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed.
Take several steps to launch the Docker image:
@ -79,3 +79,7 @@ For example, for CUDA 10.1, CuDNN7.5 install paddle 2.0.0:
```bash
python3 -m pip install paddlepaddle-gpu==2.0.0
```
- Install Deepspeech
Please see [Setup](#setup) section.

@ -1,6 +1,8 @@
# Prepare Language Model
A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. Users can simply run this to download the preprared language models:
A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. The bash script to download LM is example's `local/download_lm_*.sh`.
For example, users can simply run this to download the preprared mandarin language models:
```bash
cd examples/aishell
@ -8,7 +10,9 @@ source path.sh
bash local/download_lm_ch.sh
```
If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. Here we provide some tips to show how we preparing our English and Mandarin language models. You can take it as a reference when you train your own.
If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials.
Here we provide some tips to show how we preparing our English and Mandarin language models.
You can take it as a reference when you train your own.
## English LM

@ -2,3 +2,4 @@ data
ckpt*
demo_cache
*.log
log

@ -29,8 +29,8 @@ model:
use_gru: True
share_rnn_weights: False
training:
n_epoch: 30
lr: 5e-4
n_epoch: 50
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 5.0
@ -39,7 +39,7 @@ decoding:
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6
alpha: 1.9
beta: 5.0
beam_size: 300
cutoff_prob: 0.99

@ -1,6 +1,6 @@
#! /usr/bin/env bash
if [[ $# != 1 ]];
if [[ $# != 1 ]]; then
echo "usage: $0 ckpt-path"
exit -1
fi

@ -2,7 +2,7 @@
# train model
# if you wish to resume from an exists model, uncomment --init_from_pretrained_model
export FLAGS_sync_nccl_allreduce=0
#export FLAGS_sync_nccl_allreduce=0
ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
echo "using $ngpu gpus..."

@ -7,7 +7,7 @@ source path.sh
bash ./local/data.sh
# train model
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh baseline
# test model
CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh
@ -16,4 +16,4 @@ CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh
CUDA_VISIBLE_DEVICES=0 bash ./local/infer.sh ckpt/checkpoints/step-3284
# export model
bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model
bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model

@ -1 +1 @@
* s0 for deepspeech2
* s0 is for deepspeech

@ -29,8 +29,8 @@ model:
use_gru: False
share_rnn_weights: True
training:
n_epoch: 20
lr: 5e-4
n_epoch: 50
lr: 1e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 5.0
@ -39,7 +39,7 @@ decoding:
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
alpha: 1.9
beta: 0.3
beam_size: 500
cutoff_prob: 1.0

@ -1,6 +1,6 @@
#! /usr/bin/env bash
if [[ $# != 1 ]];
if [[ $# != 1 ]];then
echo "usage: $0 ckpt-path"
exit -1
fi

@ -1,8 +1,9 @@
#! /usr/bin/env bash
export FLAGS_sync_nccl_allreduce=0
#export FLAGS_sync_nccl_allreduce=0
# https://github.com/PaddlePaddle/Paddle/pull/28484
export NCCL_SHM_DISABLE=1
#export NCCL_SHM_DISABLE=1
ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
echo "using $ngpu gpus..."
@ -11,7 +12,7 @@ python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \
--nproc ${ngpu} \
--config conf/deepspeech2.yaml \
--output ckpt
--output ckpt-${1}
if [ $? -ne 0 ]; then
echo "Failed in training!"

@ -1,6 +1,6 @@
#! /usr/bin/env bash
if [[ $# != 1 ]];
if [[ $# != 1 ]];then
echo "usage: $0 ckpt-path"
exit -1
fi

@ -6,7 +6,6 @@ if [ $? -ne 0 ]; then
exit 1
fi
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/test.py \
--device 'gpu' \
--nproc 1 \

@ -2,7 +2,6 @@
export FLAGS_sync_nccl_allreduce=0
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \
--nproc 1 \

@ -1,5 +1,8 @@
#! /usr/bin/env bash
source utils/log.sh
SUDO='sudo'
if [ $(id -u) -eq 0 ]; then
SUDO=''
@ -8,6 +11,8 @@ fi
if [ -e /etc/lsb-release ];then
#${SUDO} apt-get update
${SUDO} apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev
error_msg "Please using Ubuntu or install `pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev` by user."
exit -1
fi
# install python dependencies
@ -15,17 +20,17 @@ if [ -f "requirements.txt" ]; then
pip3 install -r requirements.txt
fi
if [ $? != 0 ]; then
echo "Install python dependencies failed !!!"
error_msg "Install python dependencies failed !!!"
exit 1
fi
# install package libsndfile
python3 -c "import soundfile"
if [ $? != 0 ]; then
echo "Install package libsndfile into default system path."
info_msg "Install package libsndfile into default system path."
wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz"
if [ $? != 0 ]; then
echo "Download libsndfile-1.0.28.tar.gz failed !!!"
error_msg "Download libsndfile-1.0.28.tar.gz failed !!!"
exit 1
fi
tar -zxvf libsndfile-1.0.28.tar.gz
@ -43,6 +48,10 @@ if [ $? != 0 ]; then
sh setup.sh
cd - > /dev/null
fi
python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")"
if [ $? != 0 ]; then
error_msg "Please check why decoder install error!"
exit -1
fi
echo "Install all dependencies successfully."
info_msg "Install all dependencies successfully."

@ -0,0 +1,11 @@
_HDR_FMT="%.23s %s[%s]: "
_ERR_MSG_FMT="ERROR: ${_HDR_FMT}%s\n"
_INFO_MSG_FMT="INFO: ${_HDR_FMT}%s\n"
error_msg() {
printf "$_ERR_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}"
}
info_msg() {
printf "$_INFO_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}"
}
Loading…
Cancel
Save