diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..b31d98631 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,42 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + +For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions). + +If you've found a bug then please create an issue with the following information: + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. Go to '...' +2. Click on '....' +3. Scroll down to '....' +4. See error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +** Environment (please complete the following information):** + - OS: [e.g. Ubuntu] + - GCC/G++ Version [e.g. 8.3] + - Python Version [e.g. 3.7] + - PaddlePaddle Version [e.g. 2.0.0] + - Model Version [e.g. 2.0.0] + - GPU/DRIVER Informationo [e.g. Tesla V100-SXM2-32GB/440.64.00] + - CUDA/CUDNN Version [e.g. cuda-10.2] + - MKL Version +- TensorRT Version + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..94d507035 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,24 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: "[Feature request]" +labels: feature request +assignees: '' + +--- + +For support and discussions, please use our [Discourse forums](https://github.com/PaddlePaddle/DeepSpeech/discussions). + +If you've found a feature request then please create an issue with the following information: + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.travis/unittest.sh b/.travis/unittest.sh index c152a1bc5..416042c8c 100755 --- a/.travis/unittest.sh +++ b/.travis/unittest.sh @@ -11,6 +11,13 @@ abort(){ unittest(){ cd $1 > /dev/null + if [ -f "setup.sh" ]; then + bash setup.sh + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + fi + if [ $? != 0 ]; then + exit 1 + fi find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \ xargs -0 -I{} -n1 bash -c \ 'python3 -m unittest discover -v -s {}' @@ -19,6 +26,15 @@ unittest(){ coverage(){ cd $1 > /dev/null + + if [ -f "setup.sh" ]; then + bash setup.sh + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + fi + if [ $? != 0 ]; then + exit 1 + fi + find . -path ./tools/venv -prune -false -o -name 'tests' -type d -print0 | \ xargs -0 -I{} -n1 bash -c \ 'python3 -m coverage run --branch {}' diff --git a/README.md b/README.md index f18881499..9330c005f 100644 --- a/README.md +++ b/README.md @@ -21,26 +21,13 @@ * python>=3.7 * paddlepaddle>=2.0.0 -- Run the setup script for the remaining dependencies - -```bash -git clone https://github.com/PaddlePaddle/DeepSpeech.git -cd DeepSpeech -pushd tools; make; popd -source tools/venv/bin/activate -bash setup.sh -``` - -- Source venv before do experiment. - -```bash -source tools/venv/bin/activate -``` +Please see [install](docs/install.md). ## Getting Started Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples/tiny/README.md). + ## More Information * [Install](docs/src/install.md) @@ -56,7 +43,7 @@ Please see [Getting Started](docs/src/geting_started.md) and [tiny egs](examples ## Questions and Help -You are welcome to submit questions and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project. +You are welcome to submit questions in [Github Discussions](https://github.com/PaddlePaddle/DeepSpeech/discussions) and bug reports in [Github Issues](https://github.com/PaddlePaddle/DeepSpeech/issues). You are also welcome to contribute to this project. ## License diff --git a/README_cn.md b/README_cn.md index 769130472..b50d205e9 100644 --- a/README_cn.md +++ b/README_cn.md @@ -18,24 +18,11 @@ ## 安装 + * python>=3.7 * paddlepaddle>=2.0.0 -- 安装依赖 - -```bash -git clone https://github.com/PaddlePaddle/DeepSpeech.git -cd DeepSpeech -pushd tools; make; popd -source tools/venv/bin/activate -bash setup.sh -``` - -- 开始实验前要source环境. - -```bash -source tools/venv/bin/activate -``` +参看 [安装](docs/install.md)。 ## 开始 @@ -55,7 +42,7 @@ source tools/venv/bin/activate ## 问题和帮助 -欢迎您在[Github问题](https://github.com/PaddlePaddle/models/issues)中提交问题和bug。也欢迎您为这个项目做出贡献。 +欢迎您在[Github讨论](https://github.com/PaddlePaddle/DeepSpeech/discussions)提交问题,[Github问题](https://github.com/PaddlePaddle/models/issues)中反馈bug。也欢迎您为这个项目做出贡献。 ## License diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 75335d318..e3d6369bb 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -58,8 +58,6 @@ class DeepSpeech2Trainer(Trainer): losses_np = { 'train_loss': float(loss), - 'train_loss_div_batchsize': - float(loss) / self.config.data.batch_size } msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) @@ -85,8 +83,6 @@ class DeepSpeech2Trainer(Trainer): loss = self.model(*batch) valid_losses['val_loss'].append(float(loss)) - valid_losses['val_loss_div_batchsize'].append( - float(loss) / self.config.data.batch_size) # write visual log valid_losses = {k: np.mean(v) for k, v in valid_losses.items()} @@ -265,7 +261,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): self.logger.info(msg) def run_test(self): - self.resume_or_load() + self.resume_or_scratch() try: self.test() except KeyboardInterrupt: diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index b5f90f046..38e24cef5 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -80,12 +80,15 @@ class U2Trainer(Trainer): self.model.train() start = time.time() + loss, attention_loss, ctc_loss = self.model(*batch_data) loss.backward() layer_tools.print_grads(self.model, print_func=None) + if self.iteration % train_conf.accum_grad == 0: self.optimizer.step() self.optimizer.clear_grad() + self.lr_scheduler.step() iteration_time = time.time() - start @@ -102,11 +105,49 @@ class U2Trainer(Trainer): if self.iteration % train_conf.log_interval == 0: self.logger.info(msg) + # display if dist.get_rank() == 0 and self.visualizer: for k, v in losses_np.items(): self.visualizer.add_scalar("train/{}".format(k), v, self.iteration) + def train(self): + """The training process. + It includes forward/backward/update and periodical validation and + saving. + """ + # !!!IMPORTANT!!! + # Try to export the model by script, if fails, we should refine + # the code to satisfy the script export requirements + # script_model = paddle.jit.to_static(self.model) + # script_model_path = str(self.checkpoint_dir / 'init') + # paddle.jit.save(script_model, script_model_path) + + from_scratch = self.resume_or_scratch() + self.logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") + + while self.epoch <= self.config.training.n_epoch: + try: + data_start_time = time.time() + for batch in self.train_loader: + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "lr: {}, ".foramt(self.lr_scheduler()) + msg += "dataloader time: {:>.3f}s, ".format(dataload_time) + self.iteration += 1 + self.train_batch(batch, msg) + data_start_time = time.time() + except Exception as e: + self.logger.error(e) + raise e + + self.valid() + self.save() + self.new_epoch() + @mp_tools.rank_zero_only @paddle.no_grad() def valid(self): @@ -365,7 +406,7 @@ class U2Tester(U2Trainer): self.logger.info(msg) def run_test(self): - self.resume_or_load() + self.resume_or_scratch() try: self.test() except KeyboardInterrupt: diff --git a/deepspeech/models/deepspeech2.py b/deepspeech/models/deepspeech2.py index 7cba2f2cb..884fa4b1f 100644 --- a/deepspeech/models/deepspeech2.py +++ b/deepspeech/models/deepspeech2.py @@ -159,7 +159,8 @@ class DeepSpeech2Model(nn.Layer): enc_n_units=self.encoder.output_size, blank_id=dict_size, # last token is dropout_rate=0.0, - reduction=True) + reduction=True, # sum + batch_average=True) # sum / batch_size def forward(self, audio, audio_len, text, text_len): """Compute Model loss diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 16573a38f..90f3e3227 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -834,7 +834,14 @@ class U2Model(U2BaseModel): decoder = TransformerDecoder(vocab_size, encoder.output_size(), **configs['decoder_conf']) - ctc = CTCDecoder(vocab_size, encoder.output_size()) + ctc = CTCDecoder( + odim=vocab_size, + enc_n_units=encoder.output_size(), + blank_id=0, + dropout_rate=0.0, + reduction=True, # sum + batch_average=True) # sum / batch_size + return vocab_size, encoder, decoder, ctc @classmethod diff --git a/deepspeech/modules/ctc.py b/deepspeech/modules/ctc.py index 64508a74d..1cd7a3c85 100644 --- a/deepspeech/modules/ctc.py +++ b/deepspeech/modules/ctc.py @@ -37,14 +37,16 @@ class CTCDecoder(nn.Layer): enc_n_units, blank_id=0, dropout_rate: float=0.0, - reduction: bool=True): + reduction: bool=True, + batch_average: bool=True): """CTC decoder Args: odim ([int]): text vocabulary size enc_n_units ([int]): encoder output dimention dropout_rate (float): dropout rate (0.0 ~ 1.0) - reduction (bool): reduce the CTC loss into a scalar + reduction (bool): reduce the CTC loss into a scalar, True for 'sum' or 'none' + batch_average (bool): do batch dim wise average. """ assert check_argument_types() super().__init__() @@ -54,7 +56,10 @@ class CTCDecoder(nn.Layer): self.dropout_rate = dropout_rate self.ctc_lo = nn.Linear(enc_n_units, self.odim) reduction_type = "sum" if reduction else "none" - self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type) + self.criterion = CTCLoss( + blank=self.blank_id, + reduction=reduction_type, + batch_average=batch_average) # CTCDecoder LM Score handle self._ext_scorer = None diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py index cb65ba140..95ca644ad 100644 --- a/deepspeech/modules/loss.py +++ b/deepspeech/modules/loss.py @@ -24,32 +24,33 @@ __all__ = ['CTCLoss', "LabelSmoothingLoss"] class CTCLoss(nn.Layer): - def __init__(self, blank=0, reduction='sum'): + def __init__(self, blank=0, reduction='sum', batch_average=False): super().__init__() # last token id as blank id self.loss = nn.CTCLoss(blank=blank, reduction=reduction) + self.batch_average = batch_average def forward(self, logits, ys_pad, hlens, ys_lens): """Compute CTC loss. Args: - logits ([paddle.Tensor]): [description] - ys_pad ([paddle.Tensor]): [description] - hlens ([paddle.Tensor]): [description] - ys_lens ([paddle.Tensor]): [description] + logits ([paddle.Tensor]): [B, Tmax, D] + ys_pad ([paddle.Tensor]): [B, Tmax] + hlens ([paddle.Tensor]): [B] + ys_lens ([paddle.Tensor]): [B] Returns: [paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}. """ + B = paddle.shape(logits)[0] # warp-ctc need logits, and do softmax on logits by itself # warp-ctc need activation with shape [T, B, V + 1] # logits: (B, L, D) -> (L, B, D) logits = logits.transpose([1, 0, 2]) loss = self.loss(logits, ys_pad, hlens, ys_lens) - - # wenet do batch-size average, deepspeech2 not do this - # Batch-size average - # loss = loss / paddle.shape(logits)[1] + if self.batch_average: + # Batch-size average + loss = loss / B return loss diff --git a/deepspeech/training/scheduler.py b/deepspeech/training/scheduler.py index 54103e061..08e9d4121 100644 --- a/deepspeech/training/scheduler.py +++ b/deepspeech/training/scheduler.py @@ -54,4 +54,4 @@ class WarmupLR(LRScheduler): step_num**-0.5, step_num * self.warmup_steps**-1.5) def set_step(self, step: int): - self.last_epoch = step + self.step(step) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 982faa989..474f8d728 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -139,7 +139,7 @@ class Trainer(): checkpoint.save_parameters(self.checkpoint_dir, self.iteration, self.model, self.optimizer, infos) - def resume_or_load(self): + def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output directory or load a specified checkpoint. @@ -152,8 +152,20 @@ class Trainer(): checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) if infos: + # restore from ckpt self.iteration = infos["step"] self.epoch = infos["epoch"] + self.lr_scheduler.step(self.iteration) + if self.parallel: + self.train_loader.batch_sampler.set_epoch(self.epoch) + return False + else: + # from scratch, epoch and iteration init with zero + # save init model, i.e. 0 epoch + self.save() + # self.epoch start from 1. + self.new_epoch() + return True def new_epoch(self): """Reset the train loader and increment ``epoch``. @@ -166,22 +178,22 @@ class Trainer(): def train(self): """The training process. - It includes forward/backward/update and periodical validation and - saving. """ + from_scratch = self.resume_or_scratch() + self.logger.info( f"Train Total Examples: {len(self.train_loader.dataset)}") - self.new_epoch() while self.epoch <= self.config.training.n_epoch: try: data_start_time = time.time() for batch in self.train_loader: dataload_time = time.time() - data_start_time + # iteration start from 1. + self.iteration += 1 msg = "Train: Rank: {}, ".format(dist.get_rank()) msg += "epoch: {}, ".format(self.epoch) msg += "step: {}, ".format(self.iteration) msg += "dataloader time: {:>.3f}s, ".format(dataload_time) - self.iteration += 1 self.train_batch(batch, msg) data_start_time = time.time() except Exception as e: @@ -190,6 +202,7 @@ class Trainer(): self.valid() self.save() + # lr control by epoch self.lr_scheduler.step() self.new_epoch() @@ -197,7 +210,6 @@ class Trainer(): """The routine of the experiment after setup. This method is intended to be used by the user. """ - self.resume_or_load() try: self.train() except KeyboardInterrupt: @@ -298,7 +310,7 @@ class Trainer(): # global logger stdout = False - save_path = log_file + save_path = str(log_file) logging.basicConfig( level=logging.DEBUG if stdout else logging.INFO, format=format, diff --git a/docs/src/geting_started.md b/docs/src/getting_started.md similarity index 100% rename from docs/src/geting_started.md rename to docs/src/getting_started.md diff --git a/docs/src/install.md b/docs/src/install.md index 72b7b6988..01049a2fc 100644 --- a/docs/src/install.md +++ b/docs/src/install.md @@ -45,7 +45,7 @@ source tools/venv/bin/activate ## Running in Docker Container (optional) -Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed, including the pre-built PaddlePaddle, CTC decoders, and other necessary Python and third-party packages. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed. +Docker is an open source tool to build, ship, and run distributed applications in an isolated environment. A Docker image for this project has been provided in [hub.docker.com](https://hub.docker.com) with all the dependencies installed. This Docker image requires the support of NVIDIA GPU, so please make sure its availiability and the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) has been installed. Take several steps to launch the Docker image: @@ -79,3 +79,7 @@ For example, for CUDA 10.1, CuDNN7.5 install paddle 2.0.0: ```bash python3 -m pip install paddlepaddle-gpu==2.0.0 ``` + +- Install Deepspeech + +Please see [Setup](#setup) section. diff --git a/docs/src/ngram_lm.md b/docs/src/ngram_lm.md index 48c557ce9..1417d329e 100644 --- a/docs/src/ngram_lm.md +++ b/docs/src/ngram_lm.md @@ -1,6 +1,8 @@ # Prepare Language Model -A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. Users can simply run this to download the preprared language models: +A language model is required to improve the decoder's performance. We have prepared two language models (with lossy compression) for users to download and try. One is for English and the other is for Mandarin. The bash script to download LM is example's `local/download_lm_*.sh`. + +For example, users can simply run this to download the preprared mandarin language models: ```bash cd examples/aishell @@ -8,7 +10,9 @@ source path.sh bash local/download_lm_ch.sh ``` -If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. Here we provide some tips to show how we preparing our English and Mandarin language models. You can take it as a reference when you train your own. +If you wish to train your own better language model, please refer to [KenLM](https://github.com/kpu/kenlm) for tutorials. +Here we provide some tips to show how we preparing our English and Mandarin language models. +You can take it as a reference when you train your own. ## English LM diff --git a/examples/aishell/.gitignore b/examples/aishell/.gitignore index 389676a70..3c13afe8a 100644 --- a/examples/aishell/.gitignore +++ b/examples/aishell/.gitignore @@ -2,3 +2,4 @@ data ckpt* demo_cache *.log +log diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index e06ae0239..5a386b985 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -29,8 +29,8 @@ model: use_gru: True share_rnn_weights: False training: - n_epoch: 30 - lr: 5e-4 + n_epoch: 50 + lr: 2e-3 lr_decay: 0.83 weight_decay: 1e-06 global_grad_clip: 5.0 @@ -39,7 +39,7 @@ decoding: error_rate_type: cer decoding_method: ctc_beam_search lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm - alpha: 2.6 + alpha: 1.9 beta: 5.0 beam_size: 300 cutoff_prob: 0.99 diff --git a/examples/aishell/s0/local/infer.sh b/examples/aishell/s0/local/infer.sh index 41ccabf80..8c6a4dca2 100644 --- a/examples/aishell/s0/local/infer.sh +++ b/examples/aishell/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]]; then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/aishell/s0/local/train.sh b/examples/aishell/s0/local/train.sh index c286566a8..245ed2172 100644 --- a/examples/aishell/s0/local/train.sh +++ b/examples/aishell/s0/local/train.sh @@ -2,7 +2,7 @@ # train model # if you wish to resume from an exists model, uncomment --init_from_pretrained_model -export FLAGS_sync_nccl_allreduce=0 +#export FLAGS_sync_nccl_allreduce=0 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') echo "using $ngpu gpus..." diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh index 8beb6bf0f..2e215a999 100644 --- a/examples/aishell/s0/run.sh +++ b/examples/aishell/s0/run.sh @@ -7,7 +7,7 @@ source path.sh bash ./local/data.sh # train model -CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh +CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh baseline # test model CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh @@ -16,4 +16,4 @@ CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh CUDA_VISIBLE_DEVICES=0 bash ./local/infer.sh ckpt/checkpoints/step-3284 # export model -bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model \ No newline at end of file +bash ./local/export.sh ckpt/checkpoints/step-3284 jit.model diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md index e109e1ae4..697cb91d4 100644 --- a/examples/librispeech/README.md +++ b/examples/librispeech/README.md @@ -1 +1 @@ -* s0 for deepspeech2 +* s0 is for deepspeech diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index 81313e611..2be8f78a9 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -29,8 +29,8 @@ model: use_gru: False share_rnn_weights: True training: - n_epoch: 20 - lr: 5e-4 + n_epoch: 50 + lr: 1e-3 lr_decay: 0.83 weight_decay: 1e-06 global_grad_clip: 5.0 @@ -39,7 +39,7 @@ decoding: error_rate_type: wer decoding_method: ctc_beam_search lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm - alpha: 2.5 + alpha: 1.9 beta: 0.3 beam_size: 500 cutoff_prob: 1.0 diff --git a/examples/librispeech/s0/local/infer.sh b/examples/librispeech/s0/local/infer.sh index 6fc8d39fc..98b3b016a 100644 --- a/examples/librispeech/s0/local/infer.sh +++ b/examples/librispeech/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]];then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/librispeech/s0/local/train.sh b/examples/librispeech/s0/local/train.sh index 507947e9e..cbccb1896 100644 --- a/examples/librispeech/s0/local/train.sh +++ b/examples/librispeech/s0/local/train.sh @@ -1,8 +1,9 @@ #! /usr/bin/env bash -export FLAGS_sync_nccl_allreduce=0 +#export FLAGS_sync_nccl_allreduce=0 + # https://github.com/PaddlePaddle/Paddle/pull/28484 -export NCCL_SHM_DISABLE=1 +#export NCCL_SHM_DISABLE=1 ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));') echo "using $ngpu gpus..." @@ -11,7 +12,7 @@ python3 -u ${BIN_DIR}/train.py \ --device 'gpu' \ --nproc ${ngpu} \ --config conf/deepspeech2.yaml \ ---output ckpt +--output ckpt-${1} if [ $? -ne 0 ]; then echo "Failed in training!" diff --git a/examples/tiny/s0/local/infer.sh b/examples/tiny/s0/local/infer.sh index 1243c0d08..b36f9000a 100644 --- a/examples/tiny/s0/local/infer.sh +++ b/examples/tiny/s0/local/infer.sh @@ -1,6 +1,6 @@ #! /usr/bin/env bash -if [[ $# != 1 ]]; +if [[ $# != 1 ]];then echo "usage: $0 ckpt-path" exit -1 fi diff --git a/examples/tiny/s0/local/test.sh b/examples/tiny/s0/local/test.sh index a0f200799..8c8c278c6 100644 --- a/examples/tiny/s0/local/test.sh +++ b/examples/tiny/s0/local/test.sh @@ -6,7 +6,6 @@ if [ $? -ne 0 ]; then exit 1 fi -CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/test.py \ --device 'gpu' \ --nproc 1 \ diff --git a/examples/tiny/s0/local/train.sh b/examples/tiny/s0/local/train.sh index 369ccc924..af62ae55f 100644 --- a/examples/tiny/s0/local/train.sh +++ b/examples/tiny/s0/local/train.sh @@ -2,7 +2,6 @@ export FLAGS_sync_nccl_allreduce=0 -CUDA_VISIBLE_DEVICES=0 \ python3 -u ${BIN_DIR}/train.py \ --device 'gpu' \ --nproc 1 \ diff --git a/setup.sh b/setup.sh index a58bd7967..5141fd904 100644 --- a/setup.sh +++ b/setup.sh @@ -1,5 +1,8 @@ #! /usr/bin/env bash +source utils/log.sh + + SUDO='sudo' if [ $(id -u) -eq 0 ]; then SUDO='' @@ -8,6 +11,8 @@ fi if [ -e /etc/lsb-release ];then #${SUDO} apt-get update ${SUDO} apt-get install -y sox pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev + error_msg "Please using Ubuntu or install `pkg-config libflac-dev libogg-dev libvorbis-dev libboost-dev swig python3-dev` by user." + exit -1 fi # install python dependencies @@ -15,17 +20,17 @@ if [ -f "requirements.txt" ]; then pip3 install -r requirements.txt fi if [ $? != 0 ]; then - echo "Install python dependencies failed !!!" + error_msg "Install python dependencies failed !!!" exit 1 fi # install package libsndfile python3 -c "import soundfile" if [ $? != 0 ]; then - echo "Install package libsndfile into default system path." + info_msg "Install package libsndfile into default system path." wget "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz" if [ $? != 0 ]; then - echo "Download libsndfile-1.0.28.tar.gz failed !!!" + error_msg "Download libsndfile-1.0.28.tar.gz failed !!!" exit 1 fi tar -zxvf libsndfile-1.0.28.tar.gz @@ -43,6 +48,10 @@ if [ $? != 0 ]; then sh setup.sh cd - > /dev/null fi +python3 -c "import pkg_resources; pkg_resources.require(\"swig_decoders==1.1\")" +if [ $? != 0 ]; then + error_msg "Please check why decoder install error!" + exit -1 +fi - -echo "Install all dependencies successfully." +info_msg "Install all dependencies successfully." diff --git a/utils/log.sh b/utils/log.sh new file mode 100644 index 000000000..84591b076 --- /dev/null +++ b/utils/log.sh @@ -0,0 +1,11 @@ +_HDR_FMT="%.23s %s[%s]: " +_ERR_MSG_FMT="ERROR: ${_HDR_FMT}%s\n" +_INFO_MSG_FMT="INFO: ${_HDR_FMT}%s\n" + +error_msg() { + printf "$_ERR_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}" +} + +info_msg() { + printf "$_INFO_MSG_FMT" $(date +%F.%T.%N) ${BASH_SOURCE[1]##*/} ${BASH_LINENO[0]} "${@}" +}