Merge pull request #819 from PaddlePaddle/timer

add timer info
pull/822/head
Jackwaterveg authored 3 years ago, committed by GitHub (commit 5e063adf44)

```diff
@@ -48,9 +48,8 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             sum_square_list.append(sum_square)
 
             # debug log
-            if i < 10:
-                logger.debug(
-                    f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")
+            logger.debug(
+                f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }")
 
         # all parameters have been filterd out
         if len(sum_square_list) == 0:
@@ -77,9 +76,8 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm):
             params_and_grads.append((p, new_grad))
 
             # debug log
-            if i < 10:
-                logger.debug(
-                    f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
-                )
+            logger.debug(
+                f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}"
+            )
 
         return params_and_grads
```
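For reference, a minimal sketch of how a `grad_clip` object like this is wired into a Paddle optimizer (not part of this PR; the import path and hyperparameters are assumptions):

```python
import paddle
import paddle.nn as nn

# Assumed import path; ClipGradByGlobalNormWithLog subclasses
# paddle.nn.ClipGradByGlobalNorm, so clip_norm comes from the parent class.
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog

model = nn.Linear(16, 4)  # toy model for illustration
optimizer = paddle.optimizer.Adam(
    learning_rate=1e-3,
    parameters=model.parameters(),
    grad_clip=ClipGradByGlobalNormWithLog(clip_norm=5.0))

loss = model(paddle.randn([8, 16])).mean()
loss.backward()
optimizer.step()  # clipping runs here; with this change the pre-/post-clip
                  # norm of every gradient is logged at DEBUG level
optimizer.clear_grad()
```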

```diff
@@ -27,7 +27,7 @@ class Timer():
         do some thing
     """
 
-    def __init__(self, message):
+    def __init__(self, message=None):
         self.message = message
 
     def duration(self) -> str:
@@ -40,7 +40,8 @@ class Timer():
         return self
 
     def __exit__(self, type, value, traceback):
-        logger.info(self.message.format(self.duration()))
+        if self.message:
+            logger.info(self.message.format(self.duration()))
 
     def __call__(self) -> float:
         return time.time() - self.start
```
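For reference, a minimal sketch of how the patched Timer behaves (not part of this PR; the import path is an assumption):

```python
import time

# Assumed import path; adjust to wherever Timer is defined in the repo.
from deepspeech.training.timer import Timer

# With a message, the elapsed time is logged via logger.info() on exit.
with Timer("Epoch-Train Time Cost: {}"):
    time.sleep(0.1)  # stand-in for real work

# With the new message=None default, nothing is logged on exit;
# the elapsed time can still be read through __call__.
t = Timer()
with t:
    time.sleep(0.1)
print(f"elapsed: {t():.3f}s")  # time.time() - t.start
```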

```diff
@@ -185,46 +185,47 @@ class Trainer():
     def train(self):
         """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        with Timer("Load/Init Model: {}"):
+            from_scratch = self.resume_or_scratch()
+            if from_scratch:
+                # save init model, i.e. 0 epoch
+                self.save(tag='init', infos=None)
+            self.lr_scheduler.step(self.epoch)
+            if self.parallel and hasattr(self.train_loader, "batch_sampler"):
+                self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-                    data_start_time = time.time()
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
```

@@ -1,8 +1,8 @@
# Deepspeech2
## Streaming
The implemented architecture of the Deepspeech2 online model is based on the [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes.
The model is mainly composed of a 2D convolution subsampling layer and stacked single-direction RNN layers.
To illustrate the model implementation clearly, 3 parts are described in detail.
- Data Preparation
@@ -11,10 +11,10 @@ To illustrate the model implementation clearly, 3 parts are described in detail.
In addition, the training process and the testing process are also introduced.
The architecture of the model is shown in Fig.1.
<p align="center">
<img src="../images/ds2onlineModel.png" width=800>
<br/>Fig.1 The architecture of the Deepspeech2 online model
</p>
@@ -28,17 +28,17 @@ For English data, the vocabulary dictionary is composed of 26 English characters
--unit_type="char" \
--count_threshold=0 \
--vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train.raw" "data/manifest.dev.raw"
# vocabulary for aishell dataset (Mandarin)
vi examples/aishell/s0/data/vocab.txt
# vocabulary for librispeech dataset (English)
vi examples/librispeech/s0/data/vocab.txt
```
#### CMVN
For CMVN, a subset of the training set (or the full training set) is chosen and used to compute the feature mean and std.
```
# The code to compute the feature mean and std
cd examples/aishell/s0
@@ -52,16 +52,16 @@ python3 ../../../utils/compute_mean_std.py \
--use_dB_normalization=True \
--num_samples=2000 \
--num_workers=10 \
--output_path="data/mean_std.json"
```
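Conceptually, the CMVN statistics are just the per-dimension mean and standard deviation of the features over the sampled utterances. A rough NumPy sketch of the idea (not the project's compute_mean_std.py; the field names and output format are illustrative):

```python
import json
import numpy as np

# Stand-in for features from a few sampled utterances, each (num_frames, feat_dim).
utterance_feats = [np.random.randn(120, 161), np.random.randn(95, 161)]

stacked = np.concatenate(utterance_feats, axis=0)  # (total_frames, feat_dim)
stats = {
    "mean": stacked.mean(axis=0).tolist(),  # per-dimension mean
    "std": stacked.std(axis=0).tolist(),    # per-dimension std
    "frame_num": int(stacked.shape[0]),
}
with open("mean_std.json", "w") as f:  # same role as data/mean_std.json
    json.dump(stats, f)
```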
#### Feature Extraction
For feature extraction, three methods are implemented: linear (FFT without using a filter bank), fbank and mfcc.
Currently, the released deepspeech2 online model uses the linear feature extraction method.
```
The code for feature extraction
vi deepspeech/frontend/featurizer/audio_featurizer.py
```
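As a rough illustration of the linear method (an FFT magnitude spectrogram with no filter bank), not the code in audio_featurizer.py:

```python
import numpy as np

def linear_spectrogram(samples, sample_rate=16000, win_ms=20.0, hop_ms=10.0):
    """Log-magnitude STFT features, one frame per row (illustrative only)."""
    win = int(sample_rate * win_ms / 1000)
    hop = int(sample_rate * hop_ms / 1000)
    window = np.hanning(win)
    frames = []
    for start in range(0, len(samples) - win + 1, hop):
        frame = samples[start:start + win] * window
        spec = np.abs(np.fft.rfft(frame))   # linear-frequency bins, no filter bank
        frames.append(np.log(spec + 1e-10))
    return np.stack(frames)

feats = linear_spectrogram(np.random.randn(16000))  # 1 s of fake audio
print(feats.shape)  # (num_frames, num_fft_bins)
```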
### Encoder
@@ -70,7 +70,7 @@ The code of Encoder is in:
```
vi deepspeech/models/ds2_online/deepspeech2.py
```
### Decoder
To get the character probabilities of each frame, the per-frame feature representation output from the backbone is fed into a projection layer, which is implemented as a dense layer. The output dimension of the projection layer is the same as the vocabulary size. After the projection layer, the softmax function turns the frame-level feature representation into character probabilities. During inference, the character probabilities of each frame are fed into the CTC decoder to get the final speech recognition results.
The code of the Decoder is in:
@@ -78,7 +78,7 @@
```
vi deepspeech/models/ds2_online/deepspeech2.py
vi deepspeech/modules/ctc.py
```
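A minimal sketch of the projection-plus-softmax step described above, with a greedy collapse standing in for the real CTC decoder (shapes, names and the blank id are illustrative, not taken from the repo):

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

rnn_dim, vocab_size, blank_id = 1024, 29, 0   # illustrative sizes
projection = nn.Linear(rnn_dim, vocab_size)   # dense layer, output dim = vocab size

frames = paddle.randn([1, 50, rnn_dim])        # (batch, time, feat) from the backbone
probs = F.softmax(projection(frames), axis=-1) # per-frame character probabilities

# Greedy CTC-style decoding: best char per frame, collapse repeats, drop blanks.
best = paddle.argmax(probs, axis=-1).numpy()[0]
decoded, prev = [], None
for idx in best:
    if idx != prev and idx != blank_id:
        decoded.append(int(idx))
    prev = idx
print(decoded)  # token ids; map through the vocabulary to get characters
```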
## Training Process
Using the command below, you can train the deepspeech2 online model.
```
@@ -120,7 +120,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
fi
```
By using the command above, the training process can be started. There are 5 stages in run.sh, and the first 3 stages are used for the training process. Stage 0 is used for data preparation, in which the dataset is downloaded and the manifest files of the datasets, the vocabulary dictionary and the CMVN file are generated in "./data/". Stage 1 is used for training the model; the log files and model checkpoints are saved in "exp/deepspeech2_online/". Stage 2 is used to generate the final model for prediction by averaging the top-k model parameters based on validation loss.
## Testing Process
Using the command below, you can test the deepspeech2 online model.
```
@@ -129,7 +129,7 @@
The detail commands are:
```
conf_path=conf/deepspeech2_online.yaml
avg_num=1
model_type=online
avg_ckpt=avg_${avg_num}
@@ -150,29 +150,29 @@ fi
```
After the training process, we use stages 3, 4 and 5 for the testing process. Stage 3 tests the model generated in stage 2 and reports the CER on the test set. Stage 4 transforms the model from a dynamic graph to a static graph by using the "paddle.jit" library. Stage 5 tests the model in static-graph mode.
## Non-Streaming
The deepspeech2 offline model is similar to the deepspeech2 online model. The main difference between them is that the offline model uses bi-directional RNN layers while the online model uses single-direction RNN layers, and the FC layer is not used.
The architecture of the model is shown in Fig.2.
<p align="center">
<img src="../images/ds2offlineModel.png" width=800>
<br/>Fig.2 The architecture of the deepspeech2 offline model
</p>
For data preparation and the decoder, the deepspeech2 offline model is the same as the deepspeech2 online model.
The code of the encoder and decoder for the deepspeech2 offline model is in:
```
vi deepspeech/models/ds2/deepspeech2.py
```
The training process and testing process of the deepspeech2 offline model are very similar to those of the deepspeech2 online model.
Only some changes should be noticed.
For training and testing, the "model_type" and the "conf_path" must be set.
```
# Training offline
cd examples/aishell/s0
@@ -184,4 +184,4 @@ cd examples/aishell/s0
bash run.sh --stage 3 --stop_stage 5 --model_type offline --conf_path conf/deepspeech2.yaml
```
