support wav2vec2ASR on librispeech

pull/2518/head
tianhao zhang 2 years ago
parent c9b0c96b7b
commit 6e429f0513

@ -0,0 +1,191 @@
# Wav2vec2ASR with Librispeech
This example contains code used to finetune [wav2vec2.0](https://https://arxiv.org/pdf/2006.11477.pdf) model with [Librispeech dataset](http://www.openslr.org/resources/12)
## Overview
All the scripts you need are in `run.sh`. There are several stages in `run.sh`, and each stage has its function.
| Stage | Function |
|:---- |:----------------------------------------------------------- |
| 0 | Process data. It includes: <br> (1) Download the dataset <br> (2) Calculate the CMVN of the train dataset <br> (3) Get the vocabulary file <br> (4) Get the manifest files of the train, development and test dataset<br> (5) Download the pretrained wav2vec2 model |
| 1 | Train the model |
| 2 | Get the final model by averaging the top-k models, set k = 1 means to choose the best model |
| 3 | Test the final model performance |
| 4 | Infer the single audio file |
You can choose to run a range of stages by setting `stage` and `stop_stage `.
For example, if you want to execute the code in stage 2 and stage 3, you can run this script:
```bash
bash run.sh --stage 2 --stop_stage 3
```
Or you can set `stage` equal to `stop-stage` to only run one stage.
For example, if you only want to run `stage 0`, you can use the script below:
```bash
bash run.sh --stage 0 --stop_stage 0
```
The document below will describe the scripts in `run.sh` in detail.
## The Environment Variables
The path.sh contains the environment variables.
```bash
. ./path.sh
. ./cmd.sh
```
This script needs to be run first. And another script is also needed:
```bash
source ${MAIN_ROOT}/utils/parse_options.sh
```
It will support the way of using `--variable value` in the shell scripts.
## The Local Variables
Some local variables are set in `run.sh`.
`gpus` denotes the GPU number you want to use. If you set `gpus=`, it means you only use CPU.
`stage` denotes the number of stages you want to start from in the experiments.
`stop stage` denotes the number of the stage you want to end at in the experiments.
`conf_path` denotes the config path of the model.
`avg_num` denotes the number K of top-K models you want to average to get the final model.
`audio file` denotes the file path of the single file you want to infer in stage 5
`ckpt` denotes the checkpoint prefix of the model, e.g. "wav2vec2ASR"
You can set the local variables (except `ckpt`) when you use `run.sh`
For example, you can set the `gpus` and `avg_num` when you use the command line:
```bash
bash run.sh --gpus 0,1 --avg_num 20
```
## Stage 0: Data Processing
To use this example, you need to process data firstly and you can use stage 0 in `run.sh` to do this. The code is shown below:
```bash
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
```
Stage 0 is for processing the data.
If you only want to process the data. You can run
```bash
bash run.sh --stage 0 --stop_stage 0
```
You can also just run these scripts in your command line.
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
```
After processing the data, the `data` directory will look like this:
```bash
data/
|-- dev.meta
|-- lang_char
| `-- bpe_unigram_5000.model
| `-- bpe_unigram_5000.vocab
| `-- vocab.txt
|-- manifest.dev
|-- manifest.dev.raw
|-- manifest.test
|-- manifest.test.raw
|-- manifest.train
|-- manifest.train.raw
|-- mean_std.json
|-- test.meta
`-- train.meta
```
## Stage 1: Model Training
If you want to train the model. you can use stage 1 in `run.sh`. The code is shown below.
```bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
fi
```
If you want to train the model, you can use the script below to execute stage 0 and stage 1:
```bash
bash run.sh --stage 0 --stop_stage 1
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wav2vec2ASR.yaml wav2vec2ASR
```
## Stage 2: Top-k Models Averaging
After training the model, we need to get the final model for testing and inference. In every epoch, the model checkpoint is saved, so we can choose the best model from them based on the validation loss or we can sort them and average the parameters of the top-k models to get the final model. We can use stage 2 to do this, and the code is shown below. Note: We only train one epoch for wav2vec2ASR, thus the `avg_num` is set to 1.
```bash
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
```
The `avg.sh` is in the `../../../utils/` which is define in the `path.sh`.
If you want to get the final model, you can use the script below to execute stage 0, stage 1, and stage 2:
```bash
bash run.sh --stage 0 --stop_stage 2
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wav2vec2ASR.yaml wav2vec2ASR
avg.sh best exp/wav2vec2ASR/checkpoints 1
```
## Stage 3: Model Testing
The test stage is to evaluate the model performance. The code of test stage is shown below:
```bash
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
```
If you want to train a model and test it, you can use the script below to execute stage 0, stage 1, stage 2, and stage 3 :
```bash
bash run.sh --stage 0 --stop_stage 3
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/wav2vec2ASR.yaml wav2vec2ASR
avg.sh best exp/wav2vec2ASR/checkpoints 1
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1
```
## Pretrained Model
You can get the pretrained wav2vec2ASR from [this](../../../docs/source/released_model.md).
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz
tar xzvf wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz
source path.sh
# If you have process the data and get the manifest file you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1
```
The performance of the released models are shown in [here](./RESULTS.md).
## Stage 4: Single Audio File Inference
In some situations, you want to use the trained model to do the inference for the single audio file. You can use stage 5. The code is shown below
```bash
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi
```
you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz
tar xzvf wav2vec2ASR-large-960h-librispeech_ckpt_1.3.0.model.tar.gz
```
You can download the audio demo:
```bash
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
```
You need to prepare an audio file or use the audio demo above, please confirm the sample rate of the audio is 16K. You can get the result of the audio demo by running the script below.
```bash
CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/wav2vec2ASR.yaml conf/tuning/decode.yaml exp/wav2vec2ASR/checkpoints/avg_1 data/demo_002_en.wav
```

@ -0,0 +1,8 @@
# LibriSpeech
## Wav2VecASR
train: Epoch 1, 1*V100-32G, batchsize:10
| Model | Params | Config | Augmentation| Test set | Decode method | WER |
| --- | --- | --- | --- | --- | --- | --- |
| wav2vec2ASR | 302.86 M | conf/wav2vec2ASR.yaml | spec_aug | test-clean | greedy search | 0.018887 |

@ -0,0 +1,89 @@
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# -max-jobs-run <njob>: Limit the number parallel jobs. This is ignored for non-array jobs.
# --num-threads <ngpu>: Specify the number of CPU core.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have unified interface, not depending on its backend.
# These options are mapping to specific options for each backend and
# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs failed, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# =========================================================~
# Select the backend used by run.sh from "local", "sge", "slurm", or "ssh"
cmd_backend='local'
# Local machine, without any Job scheduling system
if [ "${cmd_backend}" = local ]; then
# The other usage
export train_cmd="run.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="run.pl"
# Used for "*_recog.py"
export decode_cmd="run.pl"
# "qsub" (SGE, Torque, PBS, etc.)
elif [ "${cmd_backend}" = sge ]; then
# The default setting is written in conf/queue.conf.
# You must change "-q g.q" for the "queue" for your environment.
# To know the "queue" names, type "qhost -q"
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.
export train_cmd="queue.pl"
export cuda_cmd="queue.pl"
export decode_cmd="queue.pl"
# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
# The default setting is written in conf/slurm.conf.
# You must change "-p cpu" and "-p gpu" for the "partion" for your environment.
# To know the "partion" names, type "sinfo".
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".
export train_cmd="slurm.pl"
export cuda_cmd="slurm.pl"
export decode_cmd="slurm.pl"
elif [ "${cmd_backend}" = ssh ]; then
# You have to create ".queue/machines" to specify the host to execute jobs.
# e.g. .queue/machines
# host1
# host2
# host3
# Assuming you can login them without any password, i.e. You have to set ssh keys.
export train_cmd="ssh.pl"
export cuda_cmd="ssh.pl"
export decode_cmd="ssh.pl"
# This is an example of specifying several unique options in the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then
export train_cmd="queue.pl --mem 2G"
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/gpu.conf"
export decode_cmd="queue.pl --mem 4G"
else
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
return 1
fi

@ -0,0 +1,4 @@
process:
# extract kaldi fbank from PCM
- type: wav_process
dither: 0.1

@ -0,0 +1,11 @@
decode_batch_size: 1
error_rate_type: wer
decoding_method: ctc_greedy_search # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
beam_size: 10
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.

@ -0,0 +1,120 @@
############################################
# Network Architecture #
############################################
freeze_wav2vec2: False
normalize_wav: True
output_norm: True
dnn_blocks: 2
dnn_neurons: 1024
blank_id: 0
ctc_dropout_rate: 0.0
wav2vec2_params_path: "exp/wav2vec2/wav2vec2-large-960h-lv60-self.pdparams"
############################################
# Wav2Vec2.0 #
############################################
vocab_size: 32
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
intermediate_size: 4096
hidden_act: "gelu"
hidden_dropout: 0.1
activation_dropout: 0.1
attention_dropout: 0.1
feat_proj_dropout: 0.1
feat_quantizer_dropout: 0.0
final_dropout: 0.1
layerdrop: 0.1
initializer_range: 0.02
layer_norm_eps: 1e-5
feat_extract_norm: "layer"
feat_extract_activation: "gelu"
conv_dim: [512, 512, 512, 512, 512, 512, 512]
conv_stride: [5, 2, 2, 2, 2, 2, 2]
conv_kernel: [10, 3, 3, 3, 3, 2, 2]
conv_bias: True
num_conv_pos_embeddings: 128
num_conv_pos_embedding_groups: 16
do_stable_layer_norm: True
apply_spec_augment: False
mask_time_prob: 0.05
mask_time_length: 10
mask_time_min_masks: 2
mask_feature_prob: 0.0
mask_feature_length: 10
mask_feature_min_masks: 0
num_codevectors_per_group: 320
num_codevector_groups: 2
contrastive_logits_temperature: 0.1
num_negatives: 100
codevector_dim: 256
proj_codevector_dim: 256
diversity_loss_weight: 0.1
ctc_loss_reduction: "sum"
ctc_zero_infinity: False
use_weighted_layer_sum: False
pad_token_id: 0
bos_token_id: 1
eos_token_id: 2
add_adapter: False
adapter_kernel_size: 3
adapter_stride: 2
num_adapter_layers: 3
output_hidden_size: None
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
unit_type: 'char'
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for 'other' epochs
batch_size: 10 # Different batch_size may cause large differences in results
maxlen_in: 51200000000 # if input length > maxlen-in batchsize is automatically reduced
maxlen_out: 1500000 # if output length > maxlen-out batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1
dist_sampler: True
shortest_first: True
return_lens_rate: True
###########################################
# Training #
###########################################
n_epoch: 1
accum_grad: 1
global_grad_clip: 3.0
model_optim: adadelta
model_optim_conf:
lr: 0.9
epsilon: 1.0e-6
rho: 0.95
scheduler: constantlr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 50
latest_n: 5
augment: True

@ -0,0 +1,84 @@
#!/bin/bash
set -e
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
expdir=exp
datadir=data
train_set=train_960
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
python3 utils/format_rsl.py \
--origin_ref data/manifest.test-clean.raw \
--trans_ref data/manifest.test-clean.text
for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=16
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
python3 utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.${type}.rsl \
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
python3 utils/compute-wer.py --char=1 --v=1 \
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
echo "decoding ${type} done."
done
for type in ctc_prefix_beam_search; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
python3 utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.${type}.rsl \
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
python3 utils/compute-wer.py --char=1 --v=1 \
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
echo "decoding ${type} done."
done
echo "Finished"
exit 0

@ -0,0 +1,58 @@
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
if [ $? -ne 0 ]; then
exit 1
fi
if [ ! -f ${audio_file} ]; then
echo "Plase input the right audio_file path"
exit 1
fi
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size} \
--audio_file ${audio_file}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0

@ -0,0 +1,55 @@
#!/bin/bash
if [ $# -lt 2 ] && [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
ips=$3
if [ ! $ips ];then
ips_config=
else
ips_config="--ips="${ips}
fi
mkdir -p exp
# seed may break model convergence
seed=1998
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
# export FLAGS_cudnn_exhaustive_search=true
# export FLAGS_conv_workspace_size_limit=4000
export FLAGS_allocator_strategy=naive_best_fit
if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed}
fi
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0

@ -0,0 +1,15 @@
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=wav2vec2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin

@ -0,0 +1,48 @@
#!/bin/bash
set -e
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;
gpus=0
stage=0
stop_stage=0
conf_path=conf/wav2vec2ASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
dict_path=data/lang_char/vocab.txt
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
audio_file=data/demo_002_en.wav
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# attetion resocre decoder
CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=${gpus} ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,66 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for wav2vec2.0 model."""
import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
# TODO(hui zhang): dynamic load
def main_sp(config, args):
exp = Tester(config, args)
with exp.eval():
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats('test.profile')

@ -0,0 +1,118 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for wav2vec2.0 model."""
import os
import sys
from pathlib import Path
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class Wav2vec2Infer():
def __init__(self, config, args):
self.args = args
self.config = config
self.audio_file = args.audio_file
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model
model_conf = config
with UpdateConfig(model_conf):
model_conf.output_dim = self.text_feature.vocab_size
model = Wav2vec2ASR.from_config(model_conf)
self.model = model
self.model.eval()
# load model
params_path = self.args.checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
def run(self):
check(args.audio_file)
with paddle.no_grad():
# read
audio, _ = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
logger.info(f"audio shape: {audio.shape}")
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts, result_tokenids = self.model.decode(
xs,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size)
rsl = result_transcripts[0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}")
return rsl
def check(audio_file):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
sys.exit(-1)
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(audio_file)
except Exception as e:
logger.error(str(e))
logger.error(
"can not open the wav file, please check the audio file format")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
assert (sample_rate == 16000)
logger.info("The audio file format is right")
def main(config, args):
Wav2vec2Infer(config, args).run()
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args)

@ -0,0 +1,54 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for wav2vec2.0 model."""
import cProfile
import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))

@ -0,0 +1,435 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains wav2vec2 model."""
import json
import os
import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from paddlespeech.s2t.utils import mp_tools
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self.avg_train_loss = 0
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
# forward
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
# pring(wav, wavs_lens_rate, target, target_lens_rate)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
losses_np = {'loss': float(loss) * train_conf.accum_grad}
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
# When using cpu w/o DDP, model does not have `no_sync`
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step old
if (batch_index + 1) % train_conf.accum_grad == 0:
self.optimizer.step()
self.optimizer.clear_grad()
self.lr_scheduler.step()
self.iteration += 1
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)
report("batch_size", self.config.batch_size)
report("accum", train_conf.accum_grad)
report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
for i, batch in enumerate(self.valid_loader):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
if paddle.isfinite(loss):
num_utts = batch[1].shape[0]
num_seen_utts += num_utts
total_loss += float(loss) * num_utts
valid_losses['val_loss'].append(float(loss))
if (i + 1) % self.config.log_interval == 0:
valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
valid_dump['val_history_loss'] = total_loss / num_seen_utts
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
logger.info('Rank {} Val info val_loss {}'.format(
dist.get_rank(), total_loss / num_seen_utts))
return total_loss, num_seen_utts
def do_train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# script_model = paddle.jit.to_static(self.model)
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report("lr", self.lr_scheduler())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
if not self.use_streamdata:
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips,samples/s'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k.split(',')[0]}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += f" {k.split(',')[1]}" if len(
k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1] # remove the last ","
if (batch_index + 1) % self.config.log_interval == 0:
logger.info(msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
# the default operator in all_reduce function is sum.
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = total_loss / num_seen_utts
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
def setup_dataloader(self):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config
with UpdateConfig(model_conf):
if self.train:
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = Wav2vec2ASR.from_config(model_conf)
if self.parallel:
model = paddle.DataParallel(model, find_unused_parameters=True)
logger.info(f"{model}")
layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
# setup speech augmentation for wav2vec2
self.speech_augmentation = TimeDomainSpecAugment()
if not self.train:
return
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
scheduler_args = {
"learning_rate": optim_conf.lr,
"verbose": False,
"warmup_steps": scheduler_conf.warmup_steps,
"gamma": scheduler_conf.lr_decay,
"d_model": model_conf.dnn_neurons,
}
lr_scheduler = LRSchedulerFactory.from_args(scheduler_type,
scheduler_args)
def optimizer_args(
config,
parameters,
lr_scheduler=None, ):
train_config = config
optim_type = train_config.model_optim
optim_conf = train_config.model_optim_conf
scheduler_type = train_config.scheduler
scheduler_conf = train_config.scheduler_conf
return {
"grad_clip": train_config.global_grad_clip,
"learning_rate": lr_scheduler
if lr_scheduler else optim_conf.lr,
"epsilon": optim_conf.epsilon,
"rho": optim_conf.rho,
"parameters": parameters,
"epsilon": 1e-9 if optim_type == 'noam' else None,
"beta1": 0.9 if optim_type == 'noam' else None,
"beat2": 0.98 if optim_type == 'noam' else None,
}
optimzer_args = optimizer_args(config, model.parameters(), lr_scheduler)
optimizer = OptimizerFactory.from_args(optim_type, optimzer_args)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
logger.info("Setup optimizer/lr_scheduler!")
class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
def __init__(self, config, args):
super().__init__(config, args)
self.text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self.text_featurizer.vocab_list
def id2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(
self.text_featurizer.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
utts,
audio,
audio_len,
texts,
texts_len,
fout=None):
decode_cfg = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len)
result_transcripts, result_tokenids = self.model.decode(
audio,
text_feature=self.text_featurizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(
utts, target_transcripts, result_transcripts, result_tokenids):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % (
decode_cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type,
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
rtf = num_time / (num_frames)
logger.info(
"RTF: %f, Error rate [%s] (%d/?) = %f" %
(rtf, error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
error_rate_type:
errors_sum / len_refs,
"dataset_hour": (num_frames) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"err_sum":
errors_sum,
"ref_len":
len_refs,
"decode_method":
self.config.decode.decoding_method,
})
f.write(data + '\n')

@ -0,0 +1,45 @@
"""Vanilla Neural Network for simple tests.
Authors
* Elena Rastorgueva 2020
"""
import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers
from paddlespeech.s2t.models.wav2vec2.modules import linear
class VanillaNN(containers.Sequential):
"""A simple vanilla Deep Neural Network.
Arguments
---------
activation : paddle class
A class used for constructing the activation layers.
dnn_blocks : int
The number of linear neural blocks to include.
dnn_neurons : int
The number of neurons in the linear layers.
Example
-------
>>> inputs = paddle.rand([10, 120, 60])
>>> model = VanillaNN(input_shape=inputs.shape)
>>> outputs = model(inputs)
>>> outputs.shape
paddle.shape([10, 120, 512])
"""
def __init__(
self,
input_shape,
activation=paddle.nn.LeakyReLU,
dnn_blocks=2,
dnn_neurons=512,
):
super().__init__(input_shape=input_shape)
for block_index in range(dnn_blocks):
self.append(
linear.Linear,
n_neurons=dnn_neurons,
bias=True,
layer_name="linear",
)
self.append(activation(), layer_name="act")

@ -0,0 +1,175 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from packaging import version
from paddle import Tensor, nn
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
class NewGELUActivation(nn.Layer):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
class GELUActivation(nn.Layer):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
paddle.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * paddle.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Layer):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Layer):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * paddle.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Layer):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
paddle.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * paddle.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return paddle.clip(gelu(x), self.min, self.max)
class SiLUActivation(nn.Layer):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def __init__(self):
super().__init__()
self.act = nn.functional.silu
def _silu_python(self, input: Tensor) -> Tensor:
return input * paddle.sigmoid(input)
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class MishActivation(nn.Layer):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
def __init__(self):
super().__init__()
self.act = nn.functional.mish
def _mish_python(self, input: Tensor) -> Tensor:
return input * paddle.tanh(nn.functional.softplus(input))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class LinearActivation(nn.Layer):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
ACT2FN = {
"gelu": GELUActivation(),
"gelu_10": ClippedGELUActivation(-10, 10),
"gelu_fast": FastGELUActivation(),
"gelu_new": NewGELUActivation(),
"gelu_python": GELUActivation(use_gelu_python=True),
"linear": LinearActivation(),
"mish": MishActivation(),
"quick_gelu": QuickGELUActivation(),
"relu": nn.ReLU(),
"sigmoid": nn.Sigmoid(),
"silu": SiLUActivation(),
"swish": SiLUActivation(),
"tanh": nn.Tanh(),
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")

@ -0,0 +1,131 @@
import paddle
import inspect
import logging
import operator
import functools
class Sequential(paddle.nn.LayerDict):
"""A sequence of modules with potentially inferring shape on construction.
If layers are passed with names, these can be referenced with dot notation.
Arguments
---------
input_shape : iterable
A list or tuple of ints or None, representing the expected shape of an
input tensor. None represents a variable-length dimension. If no
``input_shape`` is passed, no shape inference will be performed.
*layers, **named_layers
The inputs are treated as a list of layers to be
applied in sequence. The output shape of each layer is used to
infer the shape of the following layer. If a tuple is returned,
only the shape of the first element is used to determine input
shape of the next layer (e.g. RNN returns output, hidden).
Example
-------
>>> inputs = paddle.rand(10, 40, 50)
>>> model = Sequential(input_shape=inputs.shape)
>>> model.append(Linear, n_neurons=100, layer_name="layer1")
>>> model.append(Linear, n_neurons=200, layer_name="layer2")
>>> outputs = model(inputs)
>>> outputs.shape
paddle.shape([10, 40, 200])
>>> outputs = model.layer1(inputs)
>>> outputs.shape
paddle.shape([10, 40, 100])
"""
def __init__(self, *layers, input_shape=None, **named_layers):
super().__init__()
# Make sure either layers or input_shape is passed
if not layers and input_shape is None and not named_layers:
raise ValueError("Must pass either layers or input shape")
# Keep track of what layers need "lengths" passed
self.length_layers = []
# Replace None dimensions with arbitrary value
self.input_shape = input_shape
if input_shape and None in input_shape:
self.input_shape = list(input_shape)
for i, dim in enumerate(self.input_shape):
# To reduce size of dummy tensors, use 1 for batch dim
if i == 0 and dim is None:
dim = 1
# Use 64 as nice round arbitrary value, big enough that
# halving this dimension a few times doesn't reach 1
self.input_shape[i] = dim or 256
# Append non-named layers
for layer in layers:
self.append(layer)
# Append named layers
for name, layer in named_layers.items():
self.append(layer, layer_name=name)
def append(self, layer, *args, layer_name=None, **kwargs):
"""Add a layer to the list of layers, inferring shape if necessary.
Arguments
---------
layer : A paddle.nn.Module class or object
If the layer is a class, it should accept an argument called
``input_shape`` which will be inferred and passed. If the layer
is a module object, it is added as-is.
layer_name : str
The name of the layer, for reference. If the name is in use,
``_{count}`` will be appended.
*args, **kwargs
These are passed to the layer if it is constructed.
"""
# Compute layer_name
if layer_name is None:
layer_name = str(len(self))
elif layer_name in self:
index = 0
while f"{layer_name}_{index}" in self:
index += 1
layer_name = f"{layer_name}_{index}"
# Check if it needs to be constructed with input shape
if self.input_shape:
argspec = inspect.getfullargspec(layer)
if "input_shape" in argspec.args + argspec.kwonlyargs:
input_shape = self.get_output_shape()
layer = layer(*args, input_shape=input_shape, **kwargs)
# Finally, append the layer.
try:
self[layer_name] = layer
# self.add_module(layer_name, layer)
except TypeError:
raise ValueError(
"Must pass `input_shape` at initialization and use "
"modules that take `input_shape` to infer shape when "
"using `append()`."
)
def get_output_shape(self):
"""Returns expected shape of the output.
Computed by passing dummy input constructed with the
``self.input_shape`` attribute.
"""
with paddle.no_grad():
dummy_input = paddle.zeros(self.input_shape)
dummy_output = self(dummy_input)
return dummy_output.shape
def forward(self, x):
"""Applies layers in sequence, passing only the first element of tuples.
Arguments
---------
x : paddle.Tensor
The input tensor to run through the network.
"""
for layer in self.values():
x = layer(x)
if isinstance(x, tuple):
x = x[0]
return x

@ -0,0 +1,73 @@
"""Library implementing linear transformation.
Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import logging
import paddle
import paddle.nn as nn
from paddlespeech.s2t.modules import align
logger = logging.getLogger(__name__)
class Linear(paddle.nn.Layer):
"""Computes a linear transformation y = wx + b.
Arguments
---------
n_neurons : int
It is the number of output neurons (i.e, the dimensionality of the
output).
input_shape: tuple
It is the shape of the input tensor.
input_size: int
Size of the input tensor.
bias : bool
If True, the additive bias b is adopted.
combine_dims : bool
If True and the input is 4D, combine 3rd and 4th dimensions of input.
Example
-------
>>> inputs = paddle.rand(10, 50, 40)
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
>>> output = lin_t(inputs)
>>> output.shape
paddle.shape([10, 50, 100])
"""
def __init__(
self,
n_neurons,
input_shape=None,
input_size=None,
bias=True,
combine_dims=False,
):
super().__init__()
self.combine_dims = combine_dims
if input_shape is None and input_size is None:
raise ValueError("Expected one of input_shape or input_size")
if input_size is None:
input_size = input_shape[-1]
if len(input_shape) == 4 and self.combine_dims:
input_size = input_shape[2] * input_shape[3]
# Weights are initialized following paddle approach
self.w = align.Linear(input_size, n_neurons, bias_attr=bias)
def forward(self, x):
"""Returns the linear transformation of input tensor.
Arguments
---------
x : paddle.Tensor
Input to transform linearly.
"""
if x.rank == 4 and self.combine_dims:
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
wx = self.w(x)
return wx

File diff suppressed because it is too large Load Diff

@ -0,0 +1,242 @@
"""
Low level signal processing utilities
Authors
* Peter Plantinga 2020
* Francois Grondin 2020
* William Aris 2020
* Samuele Cornell 2020
* Sarthak Yadav 2022
"""
import paddle
import math
from packaging import version
import numpy as np
def blackman_window(window_length, periodic=True):
"""Blackman window function.
Arguments
---------
window_length : int
Controlling the returned window size.
periodic : bool
Determines whether the returned window trims off the
last duplicate value from the symmetric window
Returns
-------
A 1-D tensor of size (window_length) containing the window
"""
if window_length == 0:
return []
if window_length == 1:
return paddle.ones([1])
if periodic:
window_length += 1
window = paddle.arange(window_length) * (np.pi / (window_length - 1))
window = 0.08 * paddle.cos(window * 4) - 0.5 * paddle.cos(window * 2) + 0.42
return window[:-1] if periodic else window
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
"""Compute amplitude of a batch of waveforms.
Arguments
---------
waveform : tensor
The waveforms used for computing amplitude.
Shape should be `[time]` or `[batch, time]` or
`[batch, time, channels]`.
lengths : tensor
The lengths of the waveforms excluding the padding.
Shape should be a single dimension, `[batch]`.
amp_type : str
Whether to compute "avg" average or "peak" amplitude.
Choose between ["avg", "peak"].
scale : str
Whether to compute amplitude in "dB" or "linear" scale.
Choose between ["linear", "dB"].
Returns
-------
The average amplitude of the waveforms.
Example
-------
>>> signal = paddle.sin(paddle.arange(16000.0)).unsqueeze(0)
>>> compute_amplitude(signal, signal.size(1))
tensor([[0.6366]])
"""
if len(waveforms.shape) == 1:
waveforms = waveforms.unsqueeze(0)
assert amp_type in ["avg", "peak"]
assert scale in ["linear", "dB"]
if amp_type == "avg":
if lengths is None:
out = paddle.mean(paddle.abs(waveforms), axis=1, keepdim=True)
else:
wav_sum = paddle.sum(paddle.abs(waveforms), axis=1, keepdim=True)
out = wav_sum / lengths
elif amp_type == "peak":
out = paddle.max(paddle.abs(waveforms), axis=1, keepdim=True)[0]
else:
raise NotImplementedError
if scale == "linear":
return out
elif scale == "dB":
return paddle.clip(20 * paddle.log10(out), min=-80) # clamp zeros
else:
raise NotImplementedError
def convolve1d(
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1,
use_fft=False,
rotation_index=0,
):
"""Use paddle.nn.functional to perform 1d padding and conv.
Arguments
---------
waveform : tensor
The tensor to perform operations on.
kernel : tensor
The filter to apply during convolution.
padding : int or tuple
The padding (pad_left, pad_right) to apply.
If an integer is passed instead, this is passed
to the conv1d function and pad_type is ignored.
pad_type : str
The type of padding to use. Passed directly to
`paddle.nn.functional.pad`, see Paddle documentation
for available options.
stride : int
The number of units to move each time convolution is applied.
Passed to conv1d. Has no effect if `use_fft` is True.
groups : int
This option is passed to `conv1d` to split the input into groups for
convolution. Input channels should be divisible by the number of groups.
use_fft : bool
When `use_fft` is passed `True`, then compute the convolution in the
spectral domain using complex multiply. This is more efficient on CPU
when the size of the kernel is large (e.g. reverberation). WARNING:
Without padding, circular convolution occurs. This makes little
difference in the case of reverberation, but may make more difference
with different kernels.
rotation_index : int
This option only applies if `use_fft` is true. If so, the kernel is
rolled by this amount before convolution to shift the output location.
Returns
-------
The convolved waveform.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0).unsqueeze(2)
>>> kernel = paddle.rand([1, 10, 1])
>>> signal = convolve1d(signal, kernel, padding=(9, 0))
"""
if len(waveform.shape) != 3:
raise ValueError("Convolve1D expects a 3-dimensional tensor")
# Move time dimension last, which pad and fft and conv expect.
waveform = waveform.transpose([0, 2, 1])
kernel = kernel.transpose([0, 2, 1])
# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, tuple):
waveform = paddle.nn.functional.pad(
x=waveform, pad=padding, mode=pad_type, data_format='NCL'
)
# This approach uses FFT, which is more efficient if the kernel is large
if use_fft:
# Pad kernel to same length as signal, ensuring correct alignment
zero_length = waveform.shape[-1] - kernel.shape[-1]
# Handle case where signal is shorter
if zero_length < 0:
kernel = kernel[..., :zero_length]
zero_length = 0
# Perform rotation to ensure alignment
zeros = paddle.zeros(
[kernel.shape[0], kernel.shape[1], zero_length],
dtype=kernel.dtype
)
after_index = kernel[..., rotation_index:]
before_index = kernel[..., :rotation_index]
kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
# Multiply in frequency domain to convolve in time domain
import paddle.fft as fft
result = fft.rfft(waveform) * fft.rfft(kernel)
convolved = fft.irfft(result, n=waveform.shape[-1])
# Use the implementation given by paddle, which should be efficient on GPU
else:
convolved = paddle.nn.functional.conv1d(
x=waveform,
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, tuple) else 0,
)
# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
"""Returns a notch filter constructed from a high-pass and low-pass filter.
(from https://tomroelandts.com/articles/
how-to-create-simple-band-pass-and-band-reject-filters)
Arguments
---------
notch_freq : float
frequency to put notch as a fraction of the
sampling rate / 2. The range of possible inputs is 0 to 1.
filter_width : int
Filter width in samples. Longer filters have
smaller transition bands, but are more inefficient.
notch_width : float
Width of the notch, as a fraction of the sampling_rate / 2.
"""
# Check inputs
assert 0 < notch_freq <= 1
assert filter_width % 2 != 0
pad = filter_width // 2
inputs = paddle.arange(filter_width) - pad
# Avoid frequencies that are too low
notch_freq += notch_width
# Define sinc function, avoiding division by zero
def sinc(x):
"Computes the sinc function."
def _sinc(x):
return paddle.sin(x) / x
# The zero is at the middle index
return paddle.concat([_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1 :])])
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
hlpf *= blackman_window(filter_width)
hlpf /= paddle.sum(hlpf)
# Compute a high-pass filter with cutoff frequency notch_freq.
hhpf = sinc(3 * (notch_freq + notch_width) * inputs)
hhpf *= blackman_window(filter_width)
hhpf /= -paddle.sum(hhpf)
hhpf[pad] += 1
# Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1)

@ -0,0 +1,727 @@
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import (
compute_amplitude,
convolve1d,
notch_filter)
class SpeedPerturb(nn.Layer):
"""Slightly speed up or slow down an audio signal.
Resample the audio signal at a rate that is similar to the original rate,
to achieve a slightly slower or slightly faster signal. This technique is
outlined in the paper: "Audio Augmentation for Speech Recognition"
Arguments
---------
orig_freq : int
The frequency of the original signal.
speeds : list
The speeds that the signal should be changed to, as a percentage of the
original signal (i.e. `speeds` is divided by 100 to get a ratio).
perturb_prob : float
The chance that the batch will be speed-
perturbed. By default, every batch is perturbed.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> perturbator = SpeedPerturb(orig_freq=16000, speeds=[90])
>>> clean = signal.unsqueeze(0)
>>> perturbed = perturbator(clean)
>>> clean.shape
paddle.shape([1, 52173])
>>> perturbed.shape
paddle.shape([1, 46956])
"""
def __init__(
self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0,
):
super().__init__()
self.orig_freq = orig_freq
self.speeds = speeds
self.perturb_prob = perturb_prob
# Initialize index of perturbation
self.samp_index = 0
# Initialize resamplers
self.resamplers = []
for speed in self.speeds:
config = {
"orig_freq": self.orig_freq,
"new_freq": self.orig_freq * speed // 100,
}
self.resamplers.append(Resample(**config))
def forward(self, waveform):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
# Don't perturb (return early) 1-`perturb_prob` portion of the batches
if paddle.rand([1]) > self.perturb_prob:
return waveform.clone()
# Perform a random perturbation
self.samp_index = paddle.randint(len(self.speeds), shape=(1,))[0]
perturbed_waveform = self.resamplers[self.samp_index](waveform)
return perturbed_waveform
class Resample(nn.Layer):
"""This class resamples an audio signal using sinc-based interpolation.
It is a modification of the `resample` function from torchaudio
(https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html)
Arguments
---------
orig_freq : int
the sampling frequency of the input signal.
new_freq : int
the new sampling frequency after this operation is performed.
lowpass_filter_width : int
Controls the sharpness of the filter, larger numbers result in a
sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
"""
def __init__(
self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
):
super().__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
self.lowpass_filter_width = lowpass_filter_width
# Compute rate for striding
self._compute_strides()
assert self.orig_freq % self.conv_stride == 0
assert self.new_freq % self.conv_transpose_stride == 0
def _compute_strides(self):
"""Compute the phases in polyphase filter.
(almost directly from torchaudio.compliance.kaldi)
"""
# Compute new unit based on ratio of in/out frequencies
base_freq = math.gcd(self.orig_freq, self.new_freq)
input_samples_in_unit = self.orig_freq // base_freq
self.output_samples = self.new_freq // base_freq
# Store the appropriate stride based on the new units
self.conv_stride = input_samples_in_unit
self.conv_transpose_stride = self.output_samples
def forward(self, waveforms):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
if not hasattr(self, "first_indices"):
self._indices_and_weights(waveforms)
# Don't do anything if the frequencies are the same
if self.orig_freq == self.new_freq:
return waveforms
unsqueezed = False
if len(waveforms.shape) == 2:
waveforms = waveforms.unsqueeze(1)
unsqueezed = True
elif len(waveforms.shape) == 3:
waveforms = waveforms.transpose([0, 2, 1])
else:
raise ValueError("Input must be 2 or 3 dimensions")
# Do resampling
resampled_waveform = self._perform_resample(waveforms)
if unsqueezed:
resampled_waveform = resampled_waveform.squeeze(1)
else:
resampled_waveform = resampled_waveform.transpose([0, 2, 1])
return resampled_waveform
def _perform_resample(self, waveforms):
"""Resamples the waveform at the new frequency.
This matches Kaldi's OfflineFeatureTpl ResampleWaveform which uses a
LinearResample (resample a signal at linearly spaced intervals to
up/downsample a signal). LinearResample (LR) means that the output
signal is at linearly spaced intervals (i.e the output signal has a
frequency of `new_freq`). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
(almost directly from torchaudio.compliance.kaldi)
https://ccrma.stanford.edu/~jos/resample/
Theory_Ideal_Bandlimited_Interpolation.html
https://github.com/kaldi-asr/kaldi/blob/master/src/feat/resample.h#L56
Arguments
---------
waveforms : tensor
The batch of audio signals to resample.
Returns
-------
The waveforms at the new frequency.
"""
# Compute output size and initialize
batch_size, num_channels, wave_len = waveforms.shape
window_size = self.weights.shape[1]
tot_output_samp = self._output_samples(wave_len)
resampled_waveform = paddle.zeros(
(batch_size, num_channels, tot_output_samp)
)
# self.weights = self.weights.to(waveforms.device)
# Check weights are on correct device
# if waveforms.device != self.weights.device:
# self.weights = self.weights.to(waveforms.device)
# eye size: (num_channels, num_channels, 1)
eye = paddle.eye(num_channels).unsqueeze(2)
# Iterate over the phases in the polyphase filter
for i in range(self.first_indices.shape[0]):
wave_to_conv = waveforms
first_index = int(self.first_indices[i].item())
if first_index >= 0:
# trim the signal as the filter will not be applied
# before the first_index
wave_to_conv = wave_to_conv[..., first_index:]
# pad the right of the signal to allow partial convolutions
# meaning compute values for partial windows (e.g. end of the
# window is outside the signal length)
max_index = (tot_output_samp - 1) // self.output_samples
end_index = max_index * self.conv_stride + window_size
current_wave_len = wave_len - first_index
right_padding = max(0, end_index + 1 - current_wave_len)
left_padding = max(0, -first_index)
wave_to_conv = paddle.nn.functional.pad(
wave_to_conv, (left_padding, right_padding), data_format='NCL'
)
conv_wave = paddle.nn.functional.conv1d(
x=wave_to_conv,
weight=self.weights[i].repeat(num_channels, 1, 1),
stride=self.conv_stride,
groups=num_channels,
)
# we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride]
dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
conv_wave, eye, stride=self.conv_transpose_stride
)
# pad dilated_conv_wave so it reaches the output length if needed.
left_padding = i
previous_padding = left_padding + dilated_conv_wave.shape[-1]
right_padding = max(0, tot_output_samp - previous_padding)
dilated_conv_wave = paddle.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding), data_format='NCL'
)
dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
resampled_waveform += dilated_conv_wave
return resampled_waveform
def _output_samples(self, input_num_samp):
"""Based on LinearResample::GetNumOutputSamples.
LinearResample (LR) means that the output signal is at
linearly spaced intervals (i.e the output signal has a
frequency of ``new_freq``). It uses sinc/bandlimited
interpolation to upsample/downsample the signal.
(almost directly from torchaudio.compliance.kaldi)
Arguments
---------
input_num_samp : int
The number of samples in each example in the batch.
Returns
-------
Number of samples in the output waveform.
"""
# For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
# where tick_freq is the least common multiple of samp_in and
# samp_out.
samp_in = int(self.orig_freq)
samp_out = int(self.new_freq)
tick_freq = abs(samp_in * samp_out) // math.gcd(samp_in, samp_out)
ticks_per_input_period = tick_freq // samp_in
# work out the number of ticks in the time interval
# [ 0, input_num_samp/samp_in ).
interval_length = input_num_samp * ticks_per_input_period
if interval_length <= 0:
return 0
ticks_per_output_period = tick_freq // samp_out
# Get the last output-sample in the closed interval,
# i.e. replacing [ ) with [ ]. Note: integer division rounds down.
# See http://en.wikipedia.org/wiki/Interval_(mathematics) for an
# explanation of the notation.
last_output_samp = interval_length // ticks_per_output_period
# We need the last output-sample in the open interval, so if it
# takes us to the end of the interval exactly, subtract one.
if last_output_samp * ticks_per_output_period == interval_length:
last_output_samp -= 1
# First output-sample index is zero, so the number of output samples
# is the last output-sample plus one.
num_output_samp = last_output_samp + 1
return num_output_samp
def _indices_and_weights(self, waveforms):
"""Based on LinearResample::SetIndexesAndWeights
Retrieves the weights for resampling as well as the indices in which
they are valid. LinearResample (LR) means that the output signal is at
linearly spaced intervals (i.e the output signal has a frequency
of ``new_freq``). It uses sinc/bandlimited interpolation to
upsample/downsample the signal.
Returns
-------
- the place where each filter should start being applied
- the filters to be applied to the signal for resampling
"""
# Lowpass filter frequency depends on smaller of two frequencies
min_freq = min(self.orig_freq, self.new_freq)
lowpass_cutoff = 0.99 * 0.5 * min_freq
assert lowpass_cutoff * 2 <= min_freq
window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
output_t = paddle.arange(
start=0.0, end=self.output_samples
)
output_t /= self.new_freq
min_t = output_t - window_width
max_t = output_t + window_width
min_input_index = paddle.ceil(min_t * self.orig_freq)
max_input_index = paddle.floor(max_t * self.orig_freq)
num_indices = max_input_index - min_input_index + 1
max_weight_width = num_indices.max()
j = paddle.arange(max_weight_width)
input_index = min_input_index.unsqueeze(1) + j.unsqueeze(0)
delta_t = (input_index / self.orig_freq) - output_t.unsqueeze(1)
weights = paddle.zeros_like(delta_t)
inside_window_indices = delta_t.abs() < (window_width)
# raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * (
1
+ paddle.cos(
2
* math.pi
* lowpass_cutoff
/ self.lowpass_filter_width
* delta_t[inside_window_indices]
)
)
t_eq_zero_indices = delta_t == 0.0
t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function
weights[t_not_eq_zero_indices] *= paddle.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
) / (math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
# size (output_samples, max_weight_width)
weights /= self.orig_freq
self.first_indices = min_input_index
self.weights = weights
class DropFreq(nn.Layer):
"""This class drops a random frequency from the signal.
The purpose of this class is to teach models to learn to rely on all parts
of the signal, not just a few frequency bands.
Arguments
---------
drop_freq_low : float
The low end of frequencies that can be dropped,
as a fraction of the sampling rate / 2.
drop_freq_high : float
The high end of frequencies that can be
dropped, as a fraction of the sampling rate / 2.
drop_count_low : int
The low end of number of frequencies that could be dropped.
drop_count_high : int
The high end of number of frequencies that could be dropped.
drop_width : float
The width of the frequency band to drop, as
a fraction of the sampling_rate / 2.
drop_prob : float
The probability that the batch of signals will have a frequency
dropped. By default, every batch has frequencies dropped.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropFreq()
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> dropped_signal = dropper(signal.unsqueeze(0))
"""
def __init__(
self,
drop_freq_low=1e-14,
drop_freq_high=1,
drop_count_low=1,
drop_count_high=2,
drop_width=0.05,
drop_prob=1,
):
super().__init__()
self.drop_freq_low = drop_freq_low
self.drop_freq_high = drop_freq_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_width = drop_width
self.drop_prob = drop_prob
def forward(self, waveforms):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`.
"""
# Don't drop (return early) 1-`drop_prob` portion of the batches
dropped_waveform = waveforms.clone()
if paddle.rand([1]) > self.drop_prob:
return dropped_waveform
# Add channels dimension
if len(waveforms.shape) == 2:
dropped_waveform = dropped_waveform.unsqueeze(-1)
# Pick number of frequencies to drop
drop_count = paddle.randint(
low=self.drop_count_low, high=self.drop_count_high + 1, shape=(1,),
)
# Pick a frequency to drop
drop_range = self.drop_freq_high - self.drop_freq_low
drop_frequency = (
paddle.rand(drop_count) * drop_range + self.drop_freq_low
)
# Filter parameters
filter_length = 101
pad = filter_length // 2
# Start with delta function
drop_filter = paddle.zeros([1, filter_length, 1])
drop_filter[0, pad, 0] = 1
# Subtract each frequency
for frequency in drop_frequency:
notch_kernel = notch_filter(
frequency, filter_length, self.drop_width,
)
drop_filter = convolve1d(drop_filter, notch_kernel, pad)
# Apply filter
dropped_waveform = convolve1d(dropped_waveform, drop_filter, pad)
# Remove channels dimension if added
return dropped_waveform.squeeze(-1)
class DropChunk(nn.Layer):
"""This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps a models learn to rely
on all parts of the signal, since it can't expect a given part to be
present.
Arguments
---------
drop_length_low : int
The low end of lengths for which to set the
signal to zero, in samples.
drop_length_high : int
The high end of lengths for which to set the
signal to zero, in samples.
drop_count_low : int
The low end of number of times that the signal
can be dropped to zero.
drop_count_high : int
The high end of number of times that the signal
can be dropped to zero.
drop_start : int
The first index for which dropping will be allowed.
drop_end : int
The last index for which dropping will be allowed.
drop_prob : float
The probability that the batch of signals will
have a portion dropped. By default, every batch
has portions dropped.
noise_factor : float
The factor relative to average amplitude of an utterance
to use for scaling the white noise inserted. 1 keeps
the average amplitude the same, while 0 inserts all 0's.
Example
-------
>>> from speechbrain.dataio.dataio import read_audio
>>> dropper = DropChunk(drop_start=100, drop_end=200, noise_factor=0.)
>>> signal = read_audio('tests/samples/single-mic/example1.wav')
>>> signal = signal.unsqueeze(0) # [batch, time, channels]
>>> length = paddle.ones([1])
>>> dropped_signal = dropper(signal, length)
>>> float(dropped_signal[:, 150])
0.0
"""
def __init__(
self,
drop_length_low=100,
drop_length_high=1000,
drop_count_low=1,
drop_count_high=10,
drop_start=0,
drop_end=None,
drop_prob=1,
noise_factor=0.0,
):
super().__init__()
self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.drop_start = drop_start
self.drop_end = drop_end
self.drop_prob = drop_prob
self.noise_factor = noise_factor
# Validate low < high
if drop_length_low > drop_length_high:
raise ValueError("Low limit must not be more than high limit")
if drop_count_low > drop_count_high:
raise ValueError("Low limit must not be more than high limit")
# Make sure the length doesn't exceed end - start
if drop_end is not None and drop_end >= 0:
if drop_start > drop_end:
raise ValueError("Low limit must not be more than high limit")
drop_range = drop_end - drop_start
self.drop_length_low = min(drop_length_low, drop_range)
self.drop_length_high = min(drop_length_high, drop_range)
def forward(self, waveforms, lengths):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or
`[batch, time, channels]`
"""
# Reading input list
lengths = (lengths * waveforms.shape[1]).long()
batch_size = waveforms.shape[0]
dropped_waveform = waveforms.clone()
# Don't drop (return early) 1-`drop_prob` portion of the batches
if paddle.rand([1]) > self.drop_prob:
return dropped_waveform
# Store original amplitude for computing white noise amplitude
clean_amplitude = compute_amplitude(waveforms, lengths.unsqueeze(1))
# Pick a number of times to drop
drop_times = paddle.randint(
low=self.drop_count_low,
high=self.drop_count_high + 1,
shape=(batch_size,),
)
# Iterate batch to set mask
for i in range(batch_size):
if drop_times[i] == 0:
continue
# Pick lengths
length = paddle.randint(
low=self.drop_length_low,
high=self.drop_length_high + 1,
shape=(drop_times[i],),
)
# Compute range of starting locations
start_min = self.drop_start
if start_min < 0:
start_min += lengths[i]
start_max = self.drop_end
if start_max is None:
start_max = lengths[i]
if start_max < 0:
start_max += lengths[i]
start_max = max(0, start_max - length.max())
# Pick starting locations
start = paddle.randint(
low=start_min, high=start_max + 1, shape=(drop_times[i],),
)
end = start + length
# Update waveform
if not self.noise_factor:
for j in range(drop_times[i]):
dropped_waveform[i, start[j] : end[j]] = 0.0
else:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
noise_max = 2 * clean_amplitude[i] * self.noise_factor
for j in range(drop_times[i]):
# zero-center the noise distribution
noise_vec = paddle.rand([length[j]])
noise_vec = 2 * noise_max * noise_vec - noise_max
dropped_waveform[i, start[j] : end[j]] = noise_vec
return dropped_waveform
class TimeDomainSpecAugment(nn.Layer):
"""A time-domain approximation of the SpecAugment algorithm.
This augmentation module implements three augmentations in
the time-domain.
1. Drop chunks of the audio (zero amplitude or white noise)
2. Drop frequency bands (with band-drop filters)
3. Speed peturbation (via resampling to slightly different rate)
Arguments
---------
perturb_prob : float from 0 to 1
The probability that a batch will have speed perturbation applied.
drop_freq_prob : float from 0 to 1
The probability that a batch will have frequencies dropped.
drop_chunk_prob : float from 0 to 1
The probability that a batch will have chunks dropped.
speeds : list of ints
A set of different speeds to use to perturb each batch.
See ``speechbrain.processing.speech_augmentation.SpeedPerturb``
sample_rate : int
Sampling rate of the input waveforms.
drop_freq_count_low : int
Lowest number of frequencies that could be dropped.
drop_freq_count_high : int
Highest number of frequencies that could be dropped.
drop_chunk_count_low : int
Lowest number of chunks that could be dropped.
drop_chunk_count_high : int
Highest number of chunks that could be dropped.
drop_chunk_length_low : int
Lowest length of chunks that could be dropped.
drop_chunk_length_high : int
Highest length of chunks that could be dropped.
drop_chunk_noise_factor : float
The noise factor used to scale the white noise inserted, relative to
the average amplitude of the utterance. Default 0 (no noise inserted).
Example
-------
>>> inputs = paddle.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, paddle.ones(10))
>>> feats.shape
paddle.shape([10, 12800])
"""
def __init__(
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0,
):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
)
self.drop_freq = DropFreq(
drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high,
)
self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor,
)
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.
Arguments
---------
waveforms : tensor
The waveforms to distort
"""
# Augmentation
with paddle.no_grad():
waveforms = self.speed_perturb(waveforms)
waveforms = self.drop_freq(waveforms)
waveforms = self.drop_chunk(waveforms, lengths)
return waveforms

@ -0,0 +1,247 @@
import numpy as np
import os
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.utils.utility import log_add
from collections import defaultdict
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from yacs.config import CfgNode
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):
super().__init__()
wav2vec2_config = Wav2Vec2ConfigPure(config)
wav2vec2 = Wav2Vec2Model(wav2vec2_config)
model_dict = paddle.load(config.wav2vec2_params_path)
wav2vec2.set_state_dict(model_dict)
self.normalize_wav = config.normalize_wav
self.output_norm = config.output_norm
if config.freeze_wav2vec2:
wav2vec2.eval()
for parm in wav2vec2.parameters():
parm.trainable = False
self.wav2vec2 = wav2vec2
self.enc = VanillaNN(input_shape=[None,None,wav2vec2_config.hidden_size], activation=nn.LeakyReLU, dnn_blocks=config.dnn_blocks, dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim, enc_n_units=config.dnn_neurons, blank_id=config.blank_id, dropout_rate=config.ctc_dropout_rate, reduction=True)
def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int):
batch_size = feats.shape[0]
if decoding_method is 'ctc_prefix_beam_search' and batch_size > 1:
logger.error(
f'decoding mode {decoding_method} must be running with batch_size == 1'
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
if decoding_method == 'ctc_greedy_search':
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
# ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
feats,
beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
raise ValueError(f"wav2vec2 not support decoding method: {decoding_method}")
return res, res_tokenids
@classmethod
def from_config(cls, config):
model = cls(config)
return model
def ctc_greedy_search(
self, wav) -> List[List[int]]:
""" Apply CTC greedy search
Args:
speech (paddle.Tensor): (batch, max_len)
speech_length (paddle.Tensor): (batch, )
Returns:
List[List[int]]: best path result
"""
batch_size = wav.shape[0]
wav = wav[:,:,0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
x_lens = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
def _ctc_prefix_beam_search(
self, wav, beam_size, blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
wav = wav[:,:,0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.wav2vec2(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
maxlen = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps
def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
""" Apply CTC prefix beam search
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps = self._ctc_prefix_beam_search(
wav, beam_size)
return hyps[0][0]
# @jit.to_static
# def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
# """ Export interface for c++ call, apply linear transform and log
# softmax before ctc
# Args:
# xs (paddle.Tensor): encoder output, (B, T, D)
# Returns:
# paddle.Tensor: activation before ctc
# """
# return self.ctc.log_softmax(xs)
# def _get_data(self):
# data_dir = "data"
# wavs = np.load(os.path.join(data_dir, "wavs.npy"))
# wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
# tokens = np.load(os.path.join(data_dir, "tokens.npy"))
# tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
# batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
# paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
# return batch
Loading…
Cancel
Save