[ASR] Support Hubert, fintuned on the librispeech dataset (#3088)

* librispeech hubert, test=asr

* librispeech hubert, test=asr

* hubert decode

* review

* copyright, notes, example related

* hubert cli

* pre-commit format

* fix conflicts

* fix conflicts

* doc related

* doc and train config

* librispeech.py

* support hubert cli
pull/3221/head
TianHao Zhang 2 years ago committed by GitHub
parent 8205343c65
commit 12e3e76092
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
Arguments:
- `input`(required): Audio file to recognize.
- `model`: Model type of asr task. Default: `wav2vec2ASR_librispeech`.
- `model`: Model type of asr task. Default: `wav2vec2`, choices: [wav2vec2, hubert].
- `task`: Output type. Default: `asr`.
- `lang`: Model language. Default: `en`.
- `sample_rate`: Sample rate of the model. Default: `16000`.
@ -56,7 +56,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# to recognize text
text = ssl_executor(
model='wav2vec2ASR_librispeech',
model='wav2vec2',
task='asr',
lang='en',
sample_rate=16000,

@ -36,7 +36,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
```
参数:
- `input`(必须输入):用于识别的音频文件。
- `model`ASR 任务的模型,默认值:`wav2vec2ASR_librispeech`
- `model`ASR 任务的模型,默认值:`wav2vec2`, 可选项:[wav2vec2, hubert]
- `task`:输出类别,默认值:`asr`。
- `lang`:模型语言,默认值:`en`。
- `sample_rate`:音频采样率,默认值:`16000`。
@ -56,7 +56,7 @@ wget -c https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav
# 识别文本
text = ssl_executor(
model='wav2vec2ASR_librispeech',
model='wav2vec2,
task='asr',
lang='en',
sample_rate=16000,

@ -26,6 +26,8 @@ Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions
[Wav2vec2ASR-large-960h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr3/wav2vec2ASR-large-960h-librispeech_ckpt_1.3.1.model.tar.gz) | wav2vec2 | Librispeech and LV-60k Dataset (5.3w h) | Librispeech (960 h) | 718 MB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | - | 0.0189 | [Wav2vecASR Librispeech ASR3](../../examples/librispeech/asr3) |
[Wav2vec2-large-wenetspeech-self Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2-large-wenetspeech-self_ckpt_1.3.0.model.tar.gz) | wav2vec2 | Wenetspeech Dataset (1w h) | - | 714 MB |Pre-trained Wav2vec2.0 Model | - | - | - |
[Wav2vec2ASR-large-aishell1 Model](https://paddlespeech.bj.bcebos.com/s2t/aishell/asr3/wav2vec2ASR-large-aishell1_ckpt_1.4.0.model.tar.gz) | wav2vec2 | Wenetspeech Dataset (1w h) | aishell1 (train set) | 1.18 GB |Encoder: Wav2vec2.0, Decoder: CTC, Decoding method: Greedy search | 0.0510 | - | - |
[Hubert-large-lv60 Model](https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams) | hubert | LV-60k Dataset | - | 1.18 GB |Pre-trained hubert Model | - | - | - |
[Hubert-large-100h-librispeech Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr4/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz) | hubert | LV-60k Dataset | librispeech train-clean-100 | 1.27 GB |Encoder: Hubert, Decoder: Linear + CTC, Decoding method: Greedy search | - | 0.0587 | [HubertASR Librispeech ASR4](../../examples/librispeech/asr4) |
### Whisper Model
Demo Link | Training Data | Size | Descriptions | CER | Model

@ -10,6 +10,4 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=wav2vec2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/wav2vec2/bin

@ -0,0 +1,197 @@
# Hubert2ASR with Librispeech
This example contains code used to finetune [hubert](https://arxiv.org/abs/2106.07447) model with [Librispeech dataset](http://www.openslr.org/resources/12)
## Overview
All the scripts you need are in `run.sh`. There are several stages in `run.sh`, and each stage has its function.
| Stage | Function |
|:---- |:----------------------------------------------------------- |
| 0 | Process data. It includes: <br> (1) Download the dataset <br> (2) Calculate the CMVN of the train dataset <br> (3) Get the vocabulary file <br> (4) Get the manifest files of the train, development and test dataset<br> (5) Download the pretrained wav2vec2 model |
| 1 | Train the model |
| 2 | Get the final model by averaging the top-k models, set k = 1 means to choose the best model |
| 3 | Test the final model performance |
| 4 | Infer the single audio file |
You can choose to run a range of stages by setting `stage` and `stop_stage `.
For example, if you want to execute the code in stage 2 and stage 3, you can run this script:
```bash
bash run.sh --stage 2 --stop_stage 3
```
Or you can set `stage` equal to `stop-stage` to only run one stage.
For example, if you only want to run `stage 0`, you can use the script below:
```bash
bash run.sh --stage 0 --stop_stage 0
```
The document below will describe the scripts in `run.sh` in detail.
## The Environment Variables
The path.sh contains the environment variables.
```bash
. ./path.sh
. ./cmd.sh
```
This script needs to be run first. And another script is also needed:
```bash
source ${MAIN_ROOT}/utils/parse_options.sh
```
It will support the way of using `--variable value` in the shell scripts.
## The Local Variables
Some local variables are set in `run.sh`.
`gpus` denotes the GPU number you want to use. If you set `gpus=`, it means you only use CPU.
`stage` denotes the number of stages you want to start from in the experiments.
`stop stage` denotes the number of the stage you want to end at in the experiments.
`conf_path` denotes the config path of the model.
`avg_num` denotes the number K of top-K models you want to average to get the final model.
`audio file` denotes the file path of the single file you want to infer in stage 5
`ckpt` denotes the checkpoint prefix of the model, e.g. "hubertASR"
You can set the local variables (except `ckpt`) when you use `run.sh`
For example, you can set the `gpus` and `avg_num` when you use the command line:
```bash
bash run.sh --gpus 0,1 --avg_num 20
```
## Stage 0: Data Processing
To use this example, you need to process data firstly and you can use stage 0 in `run.sh` to do this. The code is shown below:
```bash
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
```
Stage 0 is for processing the data.
If you only want to process the data. You can run
```bash
bash run.sh --stage 0 --stop_stage 0
```
You can also just run these scripts in your command line.
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
```
After processing the data, the `data` directory will look like this:
```bash
data/
|-- dev.meta
|-- lang_char
| `-- bpe_unigram_5000.model
| `-- bpe_unigram_5000.vocab
| `-- vocab.txt
|-- manifest.dev
|-- manifest.dev.raw
|-- manifest.test
|-- manifest.test.raw
|-- manifest.train
|-- manifest.train.raw
|-- mean_std.json
|-- test.meta
`-- train.meta
```
Stage 0 also downloads the pre-trained [hubert](https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams) model.
```bash
mkdir -p exp/hubert
wget -P exp/hubert https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams
```
## Stage 1: Model Training
If you want to train the model. you can use stage 1 in `run.sh`. The code is shown below.
```bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
fi
```
If you want to train the model, you can use the script below to execute stage 0 and stage 1:
```bash
bash run.sh --stage 0 --stop_stage 1
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/hubertASR.yaml hubertASR
```
## Stage 2: Top-k Models Averaging
After training the model, we need to get the final model for testing and inference. In every epoch, the model checkpoint is saved, so we can choose the best model from them based on the validation loss or we can sort them and average the parameters of the top-k models to get the final model. We can use stage 2 to do this, and the code is shown below. Note: We only train one epoch for hubertASR, thus the `avg_num` is set to 1.
```bash
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
```
The `avg.sh` is in the `../../../utils/` which is define in the `path.sh`.
If you want to get the final model, you can use the script below to execute stage 0, stage 1, and stage 2:
```bash
bash run.sh --stage 0 --stop_stage 2
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/hubertASR.yaml hubertASR
avg.sh best exp/hubertASR/checkpoints 1
```
## Stage 3: Model Testing
The test stage is to evaluate the model performance. The code of test stage is shown below:
```bash
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
```
If you want to train a model and test it, you can use the script below to execute stage 0, stage 1, stage 2, and stage 3 :
```bash
bash run.sh --stage 0 --stop_stage 3
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/hubertASR.yaml hubertASR
avg.sh best exp/hubertASR/checkpoints 1
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/hubertASR.yaml conf/tuning/decode.yaml exp/hubertASR/checkpoints/avg_1
```
## Pretrained Model
You can get the pretrained hubertASR from [this](../../../docs/source/released_model.md).
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```bash
wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
source path.sh
# If you have process the data and get the manifest file you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/hubertASR.yaml conf/tuning/decode.yaml exp/hubertASR/checkpoints/avg_1
```
The performance of the released models are shown in [here](./RESULTS.md).
## Stage 4: Single Audio File Inference
In some situations, you want to use the trained model to do the inference for the single audio file. You can use stage 5. The code is shown below
```bash
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi
```
you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
wget https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
tar xzvf hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz
```
You can download the audio demo:
```bash
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
```
You need to prepare an audio file or use the audio demo above, please confirm the sample rate of the audio is 16K. You can get the result of the audio demo by running the script below.
```bash
CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/hubertASR.yaml conf/tuning/decode.yaml exp/hubertASR/checkpoints/avg_1 data/demo_002_en.wav
```

@ -0,0 +1,9 @@
# LibriSpeech
## hubertASR
Fintuning on train-clean-100
train: Epoch 3, 1*V100-32G, batchsize: 4, accum_grad: 8
| Model | Params | Config | Augmentation| Test set | Decode method | WER |
| --- | --- | --- | --- | --- | --- | --- |
| hubertASR | 326.16M | conf/hubertASR.yaml | spec_aug | test-clean | greedy search | 0.05868 |

@ -0,0 +1,89 @@
# ====== About run.pl, queue.pl, slurm.pl, and ssh.pl ======
# Usage: <cmd>.pl [options] JOB=1:<nj> <log> <command...>
# e.g.
# run.pl --mem 4G JOB=1:10 echo.JOB.log echo JOB
#
# Options:
# --time <time>: Limit the maximum time to execute.
# --mem <mem>: Limit the maximum memory usage.
# -max-jobs-run <njob>: Limit the number parallel jobs. This is ignored for non-array jobs.
# --num-threads <ngpu>: Specify the number of CPU core.
# --gpu <ngpu>: Specify the number of GPU devices.
# --config: Change the configuration file from default.
#
# "JOB=1:10" is used for "array jobs" and it can control the number of parallel jobs.
# The left string of "=", i.e. "JOB", is replaced by <N>(Nth job) in the command and the log file name,
# e.g. "echo JOB" is changed to "echo 3" for the 3rd job and "echo 8" for 8th job respectively.
# Note that the number must start with a positive number, so you can't use "JOB=0:10" for example.
#
# run.pl, queue.pl, slurm.pl, and ssh.pl have unified interface, not depending on its backend.
# These options are mapping to specific options for each backend and
# it is configured by "conf/queue.conf" and "conf/slurm.conf" by default.
# If jobs failed, your configuration might be wrong for your environment.
#
#
# The official documentation for run.pl, queue.pl, slurm.pl, and ssh.pl:
# "Parallelization in Kaldi": http://kaldi-asr.org/doc/queue.html
# =========================================================~
# Select the backend used by run.sh from "local", "sge", "slurm", or "ssh"
cmd_backend='local'
# Local machine, without any Job scheduling system
if [ "${cmd_backend}" = local ]; then
# The other usage
export train_cmd="run.pl"
# Used for "*_train.py": "--gpu" is appended optionally by run.sh
export cuda_cmd="run.pl"
# Used for "*_recog.py"
export decode_cmd="run.pl"
# "qsub" (SGE, Torque, PBS, etc.)
elif [ "${cmd_backend}" = sge ]; then
# The default setting is written in conf/queue.conf.
# You must change "-q g.q" for the "queue" for your environment.
# To know the "queue" names, type "qhost -q"
# Note that to use "--gpu *", you have to setup "complex_value" for the system scheduler.
export train_cmd="queue.pl"
export cuda_cmd="queue.pl"
export decode_cmd="queue.pl"
# "sbatch" (Slurm)
elif [ "${cmd_backend}" = slurm ]; then
# The default setting is written in conf/slurm.conf.
# You must change "-p cpu" and "-p gpu" for the "partion" for your environment.
# To know the "partion" names, type "sinfo".
# You can use "--gpu * " by default for slurm and it is interpreted as "--gres gpu:*"
# The devices are allocated exclusively using "${CUDA_VISIBLE_DEVICES}".
export train_cmd="slurm.pl"
export cuda_cmd="slurm.pl"
export decode_cmd="slurm.pl"
elif [ "${cmd_backend}" = ssh ]; then
# You have to create ".queue/machines" to specify the host to execute jobs.
# e.g. .queue/machines
# host1
# host2
# host3
# Assuming you can login them without any password, i.e. You have to set ssh keys.
export train_cmd="ssh.pl"
export cuda_cmd="ssh.pl"
export decode_cmd="ssh.pl"
# This is an example of specifying several unique options in the JHU CLSP cluster setup.
# Users can modify/add their own command options according to their cluster environments.
elif [ "${cmd_backend}" = jhu ]; then
export train_cmd="queue.pl --mem 2G"
export cuda_cmd="queue-freegpu.pl --mem 2G --gpu 1 --config conf/gpu.conf"
export decode_cmd="queue.pl --mem 4G"
else
echo "$0: Error: Unknown cmd_backend=${cmd_backend}" 1>&2
return 1
fi

@ -0,0 +1,77 @@
{
"_name_or_path": "facebook/hubert-large-ll60k",
"activation_dropout": 0.0,
"apply_spec_augment": true,
"architectures": [
"HubertModel"
],
"attention_dropout": 0.1,
"bos_token_id": 1,
"conv_bias": true,
"conv_dim": [
512,
512,
512,
512,
512,
512,
512
],
"conv_kernel": [
10,
3,
3,
3,
3,
2,
2
],
"conv_stride": [
5,
2,
2,
2,
2,
2,
2
],
"ctc_loss_reduction": "sum",
"ctc_zero_infinity": false,
"do_stable_layer_norm": true,
"eos_token_id": 2,
"feat_extract_activation": "gelu",
"feat_extract_dropout": 0.0,
"feat_extract_norm": "layer",
"feat_proj_dropout": 0.1,
"final_dropout": 0.0,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-05,
"layerdrop": 0.1,
"mask_channel_length": 10,
"mask_channel_min_space": 1,
"mask_channel_other": 0.0,
"mask_channel_prob": 0.0,
"mask_channel_selection": "static",
"mask_feature_length": 10,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_space": 1,
"mask_time_other": 0.0,
"mask_time_prob": 0.075,
"mask_time_selection": "static",
"model_type": "hubert",
"num_attention_heads": 16,
"num_conv_pos_embedding_groups": 16,
"num_conv_pos_embeddings": 128,
"num_feat_extract_layers": 7,
"num_hidden_layers": 24,
"pad_token_id": 0,
"transformers_version": "4.10.0.dev0",
"vocab_size": 32,
"tokenizer_class": "Wav2Vec2CTCTokenizer"
}

@ -0,0 +1,142 @@
############################################
# Network Architecture #
############################################
freeze_hubert: False
normalize_wav: True
output_norm: True
init_type: kaiming_uniform # !Warning: need to convergence
enc:
input_shape: 1024
dnn_blocks: 2
dnn_neurons: 1024
activation: True
ctc:
enc_n_units: 1024
blank_id: 0
dropout_rate: 0.0
hubert_params_path: "exp/hubert/hubert-large-lv60.pdparams"
task_cfg:
label_rate: 50.0
sample_rate: 16000
normalize: True
enable_padding: False
max_keep_size: None
max_sample_size: 250000
min_sample_size: 32000
single_target: False
random_crop: True
pad_audio: False
model_cfg:
dropout_input: 0.0
final_dropout: 0.0
dropout: 0.0
attention_dropout: 0.0
activation_dropout: 0.1
apply_mask: True
mask_length: 10
mask_prob: 0.5
mask_selection: static
mask_other: 0.0
no_mask_overlap: False
mask_channel_length: 64
mask_channel_prob: 0.25
mask_channel_selection: static
mask_channel_other: 0.0
no_mask_channel_overlap: False
feature_grad_mult: 0.0
layerdrop: 0.1
normalize: True
fp16: True
label_rate: 50
extractor_mode: layer_norm
encoder_layers: 24
encoder_embed_dim: 1024
encoder_ffn_embed_dim: 4096
encoder_attention_heads: 16
activation_fn: gelu
encoder_layerdrop: 0.1
dropout_features: 0.0
final_dim: 768
untie_final_proj: True
layer_norm_first: True
conv_feature_layers: "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"
conv_bias: False
logit_temp: 0.1
target_glu: False
mask_min_space: 1
mask_channel_min_space: 1
conv_pos: 128
conv_pos_groups: 16
latent_temp: [2.0, 0.5, 0.999995]
skip_masked: False
skip_nomask: True
###########################################
# Data #
###########################################
train_manifest: data/manifest.train-clean-100
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
unit_type: char
mean_std_filepath: ""
preprocess_config: conf/preprocess.yaml
sortagrad: -1 # Feed samples from shortest to longest ; -1: enabled for all epochs 0: disabled other: enabled for other epochs
batch_size: 4 # Different batch_size may cause large differences in results
maxlen_in: 1500 # if input length > maxlen-in batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 0
subsampling_factor: 1
num_encs: 1
dist_sampler: True
shortest_first: True
return_lens_rate: True
############################################
# Data Augmentation #
############################################
audio_augment: # for raw audio
sample_rate: 16000
speeds: [95, 100, 105]
###########################################
# Training #
###########################################
n_epoch: 3
accum_grad: 8
global_grad_clip: 5.0
model_optim: adadelta
model_optim_conf:
lr: 1.0
epsilon: 1.0e-6
rho: 0.95
model_scheduler: constantlr
model_scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
hubert_optim: adadelta
hubert_optim_conf:
lr: 0.95
epsilon: 1.0e-6
rho: 0.95
hubert_scheduler: constantlr
hubert_scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 50
latest_n: 5

@ -0,0 +1,3 @@
process:
# use raw audio
- type: wav_process

@ -0,0 +1,9 @@
{
"do_normalize": true,
"feature_extractor_type": "Wav2Vec2FeatureExtractor",
"feature_size": 1,
"padding_side": "right",
"padding_value": 0,
"return_attention_mask": true,
"sampling_rate": 16000
}

@ -0,0 +1,4 @@
decode_batch_size: 1
error_rate_type: wer
decoding_method: ctc_greedy_search # 'ctc_greedy_search', 'ctc_prefix_beam_search'
beam_size: 10

@ -0,0 +1,110 @@
#!/bin/bash
stage=-1
stop_stage=100
unit_type=char
dict_dir=data/lang_char
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
mkdir -p ${dict_dir}
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# download data, generate manifests
python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \
--target_dir="${TARGET_DIR}/librispeech" \
--full_download="True"
if [ $? -ne 0 ]; then
echo "Prepare LibriSpeech failed. Terminated."
exit 1
fi
for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
mv data/manifest.${set} data/manifest.${set}.raw
done
rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
for set in train-clean-100 train-clean-360 train-other-500; do
cat data/manifest.${set}.raw >> data/manifest.train.raw
done
for set in dev-clean dev-other; do
cat data/manifest.${set}.raw >> data/manifest.dev.raw
done
for set in test-clean test-other; do
cat data/manifest.${set}.raw >> data/manifest.test.raw
done
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=2000 \
--spectrum_type="fbank" \
--feat_dim=161 \
--delta_delta=false \
--sample_rate=16000 \
--stride_ms=10 \
--window_ms=25 \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# build vocabulary
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type ${unit_type} \
--count_threshold=0 \
--vocab_path="${dict_dir}/vocab.txt" \
--manifest_paths="data/manifest.train.raw"
if [ $? -ne 0 ]; then
echo "Build vocabulary failed. Terminated."
exit 1
fi
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for set in train dev test dev-clean dev-other test-clean test-other; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.json" \
--unit_type ${unit_type} \
--vocab_path="${dict_dir}/vocab.txt" \
--manifest_path="data/manifest.${set}.raw" \
--output_path="data/manifest.${set}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest.${set} failed. Terminated."
exit 1
fi
}&
done
wait
fi
echo "LibriSpeech Data preparation done."
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
mkdir -p exp/hubert
echo "Pretrained hubert model download"
wget -P exp/hubert https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60.pdparams
fi
exit 0

@ -0,0 +1,83 @@
#!/bin/bash
set -e
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
expdir=exp
datadir=data
recog_set="test-clean test-other dev-clean dev-other"
recog_set="test-clean"
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
# download language model
#bash local/download_lm_en.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
python3 utils/format_rsl.py \
--origin_ref data/manifest.test-clean.raw \
--trans_ref data/manifest.test-clean.text
for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=16
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
python3 utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.${type}.rsl \
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
python3 utils/compute-wer.py --char=1 --v=1 \
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
echo "decoding ${type} done."
done
for type in ctc_prefix_beam_search; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
python3 utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.${type}.rsl \
--trans_hyp ${ckpt_prefix}.${type}.rsl.text
python3 utils/compute-wer.py --char=1 --v=1 \
data/manifest.test-clean.text ${ckpt_prefix}.${type}.rsl.text > ${ckpt_prefix}.${type}.error
echo "decoding ${type} done."
done
echo "Finished"
exit 0

@ -0,0 +1,58 @@
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/en/demo_002_en.wav -P data/
if [ $? -ne 0 ]; then
exit 1
fi
if [ ! -f ${audio_file} ]; then
echo "Plase input the right audio_file path"
exit 1
fi
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in ctc_greedy_search; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size} \
--audio_file ${audio_file}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0

@ -0,0 +1,58 @@
#!/bin/bash
if [ $# -lt 2 ] && [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
ckpt_name=$2
resume=$3
ips=$4
if [ ! $ips ];then
ips_config=
else
ips_config="--ips="${ips}
fi
mkdir -p exp
# seed may break model convergence
seed=1988
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
fi
# export FLAGS_cudnn_exhaustive_search=true
# export FLAGS_conv_workspace_size_limit=4000
export FLAGS_allocator_strategy=naive_best_fit
if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed} \
--resume ${resume}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--config ${config_path} \
--output exp/${ckpt_name} \
--seed ${seed} \
--resume ${resume}
fi
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0

@ -0,0 +1,13 @@
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/tools/sctk/bin:${PWD}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/hubert/bin

@ -0,0 +1,47 @@
#!/bin/bash
set -e
. ./path.sh || exit 1;
. ./cmd.sh || exit 1;
gpus=0
stage=0
stop_stage=0
conf_path=conf/hubertASR.yaml
ips= #xx.xx.xx.xx,xx.xx.xx.xx
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
resume= # xx e.g. 30
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
audio_file=data/demo_002_en.wav
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${resume} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# greedy search decoder
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi

@ -25,9 +25,6 @@ import librosa
import numpy as np
import paddle
import soundfile
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig
from yacs.config import CfgNode
from ...utils.env import MODEL_HOME
@ -37,6 +34,9 @@ from ..log import logger
from ..utils import CLI_TIMER
from ..utils import stats_wrapper
from ..utils import timer_register
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']

@ -51,11 +51,8 @@ class SSLExecutor(BaseExecutor):
self.parser.add_argument(
'--model',
type=str,
default=None,
choices=[
tag[:tag.index('-')]
for tag in self.task_resource.pretrained_models.keys()
],
default='wav2vec2',
choices=['wav2vec2', 'hubert'],
help='Choose model type of asr task.')
self.parser.add_argument(
'--task',
@ -67,7 +64,7 @@ class SSLExecutor(BaseExecutor):
'--lang',
type=str,
default='en',
help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k]'
help='Choose model language. zh or en, zh:[wav2vec2ASR_aishell1-zh-16k], en:[wav2vec2ASR_librispeech-en-16k, hubertASR_librispeech_100-en-16k]'
)
self.parser.add_argument(
"--sample_rate",
@ -137,13 +134,6 @@ class SSLExecutor(BaseExecutor):
logger.debug("start to init the model")
if model_type is None:
if lang == 'en':
model_type = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_type = 'wav2vec2ASR_aishell1'
else:
logger.error(
"invalid lang, please input --lang en or --lang zh")
logger.debug(
"Model type had not been specified, default {} was used.".
format(model_type))
@ -155,9 +145,20 @@ class SSLExecutor(BaseExecutor):
if cfg_path is None or ckpt_path is None:
sample_rate_str = '16k' if sample_rate == 16000 else '8k'
if task == 'asr':
tag = model_type + '-' + lang + '-' + sample_rate_str
if model_type == 'wav2vec2':
if lang == 'en':
model_prefix = 'wav2vec2ASR_librispeech'
elif lang == 'zh':
model_prefix = 'wav2vec2ASR_aishell1'
tag = model_prefix + '-' + lang + '-' + sample_rate_str
elif model_type == 'hubert':
if lang == 'en':
model_prefix = 'hubertASR_librispeech-100h'
elif lang == 'zh':
logger.error("zh hubertASR is not supported yet")
tag = model_prefix + '-' + lang + '-' + sample_rate_str
else:
tag = 'wav2vec2' + '-' + lang + '-' + sample_rate_str
tag = model_type + '-' + lang + '-' + sample_rate_str
self.task_resource.set_task_model(tag, version=None)
self.res_path = self.task_resource.res_dir
@ -184,16 +185,17 @@ class SSLExecutor(BaseExecutor):
self.text_feature = TextFeaturizer(
unit_type=self.config.unit_type,
vocab=self.config.vocab_filepath)
self.config.output_dim = len(self.config.vocab_filepath)
elif lang == 'zh':
self.text_feature = AutoTokenizer.from_pretrained(
self.config.tokenizer)
self.config.output_dim = self.text_feature.vocab_size
self.config.decode.decoding_method = decode_method
model_name = model_type[:model_type.rindex(
model_name = model_prefix[:model_prefix.rindex(
'_')] # model_type: {model_name}_{dataset}
else:
model_name = 'wav2vec2'
model_name = model_type
model_class = self.task_resource.get_model_class(model_name)
model_conf = self.config
model = model_class.from_config(model_conf)
self.model = model
@ -204,9 +206,9 @@ class SSLExecutor(BaseExecutor):
if task == 'asr':
self.model.set_state_dict(model_dict)
else:
self.model.wav2vec2.set_state_dict(model_dict)
getattr(self.model, model_type).set_state_dict(model_dict)
def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
def preprocess(self, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
@ -263,8 +265,7 @@ class SSLExecutor(BaseExecutor):
audio = self._inputs["audio"]
if task == 'asr':
cfg = self.config.decode
logger.debug(
f"we will use the wav2vec2ASR like model : {model_type}")
logger.debug(f"we will use the {model_type}ASR like model.")
try:
result_transcripts = self.model.decode(
audio,
@ -277,7 +278,8 @@ class SSLExecutor(BaseExecutor):
logger.exception(e)
else:
logger.debug(
"we will use the wav2vec2 like model to extract audio feature")
f"we will use the {model_type} like model to extract audio feature."
)
try:
out_feature = self.model(audio[:, :, 0])
self._outputs["result"] = out_feature[0]
@ -454,7 +456,7 @@ class SSLExecutor(BaseExecutor):
if rtf:
k = self.__class__.__name__
CLI_TIMER[k]['start'].append(time.time())
self.preprocess(model, audio_file)
self.preprocess(audio_file)
self.infer(model, task)
res = self.postprocess() # Retrieve result of asr.

@ -23,6 +23,8 @@ model_alias = {
# ---------------------------------
"wav2vec2ASR": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2ASR"],
"wav2vec2": ["paddlespeech.s2t.models.wav2vec2:Wav2vec2Base"],
"hubertASR": ["paddlespeech.s2t.models.hubert:HubertASR"],
"hubert": ["paddlespeech.s2t.models.hubert:HubertBase"],
# ---------------------------------
# -------------- ASR --------------

@ -117,6 +117,38 @@ ssl_dynamic_pretrained_models = {
'exp/wav2vec2ASR/checkpoints/avg_1.pdparams',
},
},
"hubert-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/hubert/hubert-large-lv60_ckpt_1.4.0.model.tar.gz',
'md5':
'efecfb87a8718aa9253b7459c1fe9b54',
'cfg_path':
'model.yaml',
'ckpt_path':
'hubert-large-lv60',
'model':
'hubert-large-lv60.pdparams',
'params':
'hubert-large-lv60.pdparams',
},
},
"hubertASR_librispeech-100h-en-16k": {
'1.4': {
'url':
'https://paddlespeech.bj.bcebos.com/hubert/hubertASR-large-100h-librispeech_ckpt_1.4.0.model.tar.gz',
'md5':
'574cefd11aaef5737969ce22a7f33ea2',
'cfg_path':
'model.yaml',
'ckpt_path':
'exp/hubertASR/checkpoints/avg_1',
'model':
'exp/hubertASR/checkpoints/avg_1.pdparams',
'params':
'exp/hubertASR/checkpoints/avg_1.pdparams',
},
},
}
# ---------------------------------

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,64 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for hubert model."""
import cProfile
from yacs.config import CfgNode
from paddlespeech.s2t.exps.hubert.model import HubertASRTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
def main_sp(config, args):
exp = Tester(config, args)
with exp.eval():
exp.setup()
exp.run_test()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
'--dict-path', type=str, default=None, help='dict path.')
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats('test.profile')

@ -0,0 +1,118 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for hubert model."""
import os
import sys
from pathlib import Path
import paddle
import soundfile
from yacs.config import CfgNode
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class HubertInfer():
def __init__(self, config, args):
self.args = args
self.config = config
self.audio_file = args.audio_file
self.text_feature = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model
model_conf = config
with UpdateConfig(model_conf):
model_conf.output_dim = self.text_feature.vocab_size
model = HubertASR.from_config(model_conf)
self.model = model
self.model.eval()
# load model
params_path = self.args.checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
def run(self):
check(args.audio_file)
with paddle.no_grad():
# read
audio, _ = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
logger.info(f"audio shape: {audio.shape}")
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts, result_tokenids = self.model.decode(
xs,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size)
rsl = result_transcripts[0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}")
return rsl
def check(audio_file):
if not os.path.isfile(audio_file):
print("Please input the right audio file path")
sys.exit(-1)
logger.info("checking the audio file format......")
try:
sig, sample_rate = soundfile.read(audio_file)
except Exception as e:
logger.error(str(e))
logger.error(
"can not open the wav file, please check the audio file format")
sys.exit(-1)
logger.info("The sample rate is %d" % sample_rate)
assert (sample_rate == 16000)
logger.info("The audio file format is right")
def main(config, args):
HubertInfer(config, args).run()
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
parser.add_argument(
"--audio_file", type=str, help="path of the input audio file")
args = parser.parse_args()
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.decode_cfg:
decode_confs = CfgNode(new_allowed=True)
decode_confs.merge_from_file(args.decode_cfg)
config.decode = decode_confs
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
main(config, args)

@ -0,0 +1,55 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for hubert model."""
import cProfile
import os
from yacs.config import CfgNode
from paddlespeech.s2t.exps.hubert.model import HubertASRTrainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments
def main_sp(config, args):
exp = Trainer(config, args)
exp.setup()
exp.run()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
'--resume', type=str, default="", nargs="?", help='resume ckpt path.')
args = parser.parse_args()
print_arguments(args, globals())
# https://yaml.org/type/float.html
config = CfgNode(new_allowed=True)
if args.config:
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))

@ -0,0 +1,918 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains hubert model."""
import json
import math
import os
import re
import time
from collections import OrderedDict
from contextlib import nullcontext
import jsonlines
import numpy as np
import paddle
from hyperpyyaml import load_hyperpyyaml
from paddle import distributed as dist
from paddlenlp.transformers import AutoTokenizer
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.io.speechbrain import data_pipeline
from paddlespeech.s2t.io.speechbrain import dataio
from paddlespeech.s2t.io.speechbrain import dataset
from paddlespeech.s2t.io.speechbrain.dataloader import make_dataloader
from paddlespeech.s2t.models.hubert.hubert_ASR import HubertASR
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
# Todo: change this when paddle supports this api
def clip_grad_norm_(
parameters,
max_norm,
norm_type=2.0,
error_if_nonfinite=False, ):
r"""Clips gradient norm of the iteratable parameters.
Norms are calculated together on all gradients, just as they are
connected into one vector. The gradient will be modified in place.
This API can only run in dynamic graph mode, not static graph mode.
Args:
parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
that will be normalized gradients
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be `inf` for
infinity norm.
error_if_nonfinite (bool): if True, throw an error if the total
norm of the gradients from :attr:`parameters` is `nan`,
`inf`, or `-inf`.
Returns:
Total norm of the parameter gradients (treated as a single vector).
Example:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
max_norm = float(5.0)
linear = paddle.nn.Linear(in_features=10, out_features=10)
out = linear(x)
loss = paddle.mean(out)
loss.backward()
paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm)
sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
sdg.step()
"""
if not paddle.in_dynamic_mode():
raise RuntimeError('this API can only run in dynamic mode.')
if isinstance(parameters, paddle.Tensor):
parameters = [parameters]
support_norm_type = [float("inf"), 0, 1, 2]
if norm_type not in support_norm_type:
raise ValueError(f'norm_type only support {support_norm_type}')
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return paddle.to_tensor(0.0)
if norm_type == float("inf"):
norms = [g.detach().abs().max() for g in grads]
total_norm = (norms[0]
if len(norms) == 1 else paddle.max(paddle.stack(norms)))
else:
total_norm = paddle.linalg.norm(
paddle.stack(
[paddle.linalg.norm(g.detach(), norm_type) for g in grads]),
norm_type, )
if error_if_nonfinite and paddle.logical_or(total_norm.isnan(),
total_norm.isinf()):
raise RuntimeError(
f'The total norm of {norm_type} order of the gradients from '
'`parameters` is non-finite, so it cannot be clipped. In any case, '
'disable this error and scale the gradient by non-finite norm, '
'set `error_if_nonfinite=False`')
clip_coef = max_norm / (total_norm + 1e-6)
# Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
# avoids the `if clip_coef < 1:` condition.
clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
with paddle.no_grad():
for _, p in enumerate(parameters):
g = p.grad
if g is not None:
p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
return total_norm
class HubertASRTrainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self.avg_train_loss = 0.0
self.loss_isfinite = True # while flag is 'False', loss in Nan or inf, and can not be avg
self.use_sb = True # whether use speech brain dataloader
def update_average(self, batch_index, loss):
"""Update running average of the loss.
Arguments
---------
batch_index : int
current batch index
loss : paddle.tensor
detached loss, a single float value.
"""
if math.isfinite(loss):
self.avg_train_loss -= self.avg_train_loss / (batch_index + 1)
self.avg_train_loss += loss / (batch_index + 1)
else:
self.loss_isfinite = False
logger.info('loss:{} in Nan or inf, error'.format(loss))
def before_train(self):
from_scratch = self.resume_or_scratch()
if from_scratch:
# scratch: save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
else:
# resume: train next_epoch and next_iteration
self.epoch += 1
logger.info(
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
self.maybe_batch_sampler_step()
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
# forward
## sb data pipeline
if self.use_sb:
wav, wavs_lens_rate = batch['sig']
target, target_lens_rate = batch['tokens']
target_lens = (target_lens_rate *
target.shape[1]).round().astype(paddle.int64)
else:
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
wav = wav[:, :, 0]
logger.info('training utt ids: {}'.format(utt))
if hasattr(train_conf, 'audio_augment'):
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens)
# loss div by `batch_size * accum_grad`
loss /= train_conf.accum_grad
# update self.avg_train_loss
self.update_average(batch_index, float(loss))
# loss backward
if (batch_index + 1) % train_conf.accum_grad != 0:
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
# When using cpu w/o DDP, model does not have `no_sync`
context = self.model.no_sync if (hasattr(self.model, "no_sync") and
self.parallel) else nullcontext
else:
# Used for single gpu training and DDP gradient synchronization
# processes.
context = nullcontext
with context():
loss.backward()
layer_tools.print_grads(self.model, print_func=None)
# optimizer step old
if (batch_index + 1) % train_conf.accum_grad == 0:
#do global grad clip
if train_conf.global_grad_clip != 0:
clip_grad_norm_(self.model.parameters(),
train_conf.global_grad_clip)
self.model_optimizer.step()
self.model_optimizer.clear_grad()
if not train_conf.freeze_hubert:
self.hubert_optimizer.step()
self.hubert_optimizer.clear_grad()
if self.config.model_scheduler != 'newbobscheduler':
self.model_lr_scheduler.step()
if self.config.hubert_scheduler != 'newbobscheduler':
if not train_conf.freeze_hubert:
self.hubert_lr_scheduler.step()
self.iteration += 1
losses_np = {'loss': self.avg_train_loss * train_conf.accum_grad}
iteration_time = time.time() - start
for k, v in losses_np.items():
report(k, v)
report("loss_whitoutavg", float(loss))
report("batch_size", self.config.batch_size)
report("accum", train_conf.accum_grad)
report("step_cost", iteration_time)
if (batch_index + 1) % train_conf.accum_grad == 0:
if dist.get_rank() == 0 and self.visualizer:
losses_np_v = losses_np.copy()
losses_np_v.update({
"model_lr": self.model_lr_scheduler(),
"hubert_lr": self.hubert_lr_scheduler()
})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)
@paddle.no_grad()
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = {}
step = 0
total_loss = 0.0
num_seen_utts = 1 # use update_average and no need for num_seen_utts here
for i, batch in enumerate(self.valid_loader):
if self.use_sb:
wav, wavs_lens_rate = batch['sig']
target, target_lens_rate = batch['tokens']
target_lens = (target_lens_rate *
target.shape[1]).round().astype(paddle.int64)
else:
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
wav = wav[:, :, 0]
loss = self.model(wav, wavs_lens_rate, target, target_lens)
# use update_average
total_loss -= total_loss / (step + 1)
total_loss += loss / (step + 1)
if math.isfinite(float(loss)):
step += 1
valid_losses['val_loss'] = float(loss)
else:
logger.info('loss:{} in Nan or inf, error'.format(float(loss)))
if (i + 1) % self.config.log_interval == 0:
valid_losses['val_history_loss'] = float(total_loss)
# logging
msg = f"Valid: Rank: {dist.get_rank()}, "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_losses.items())
logger.info(msg)
logger.info(
'Rank {} Val info val_loss {}'.format(dist.get_rank(), total_loss))
return total_loss, num_seen_utts
@mp_tools.rank_zero_only
def save(self, tag=None, infos: dict=None):
"""Save checkpoint (model parameters and optimizer states).
Args:
tag (int or str, optional): None for step, else using tag, e.g epoch. Defaults to None.
infos (dict, optional): meta data to save. Defaults to None.
"""
infos = infos if infos else dict()
infos.update({
"epoch": self.epoch,
"model_lr": self.model_optimizer.get_lr(),
"hubert_lr": self.hubert_optimizer.get_lr()
})
checkpoint_path = os.path.join(
self.checkpoint_dir,
"{}".format(self.iteration if tag is None else tag))
model_dict = self.model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
logger.info("Saved model to {}".format(params_path))
model_opt_dict = self.model_optimizer.state_dict()
hubert_opt_dict = self.hubert_optimizer.state_dict()
opt_dict = {'model': model_opt_dict, 'hubert': hubert_opt_dict}
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
logger.info("Saved optimzier state to {}".format(optimizer_path))
scheduler_dict = {}
if self.config.model_scheduler == 'newbobscheduler':
scheduler_dict['model'] = self.model_lr_scheduler.save()
if self.config.hubert_scheduler == 'newbobscheduler':
scheduler_dict['hubert'] = self.hubert_lr_scheduler.save()
if scheduler_dict:
scheduler_path = checkpoint_path + ".pdlrs"
paddle.save(scheduler_dict, scheduler_path)
logger.info("Saved scheduler state to {}".format(scheduler_path))
info_path = re.sub('.pdparams$', '.json', params_path)
infos = {} if infos is None else infos
with open(info_path, 'w', encoding='utf8') as fout:
data = json.dumps(infos)
fout.write(data)
def resume_or_scratch(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
scratch = None
if self.args.resume:
# just restore ckpt
# lr will resotre from optimizer ckpt
resume_json_path = os.path.join(self.checkpoint_dir,
self.args.resume + '.json')
with open(resume_json_path, 'r', encoding='utf8') as f:
resume_json = json.load(f)
self.iteration = 0
self.epoch = resume_json["epoch"]
# resotre model from *.pdparams
params_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdparams'
model_dict = paddle.load(params_path)
self.model.set_state_dict(model_dict)
# resotre optimizer from *.pdopt
optimizer_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdopt'
optimizer_dict = paddle.load(optimizer_path)
self.model_optimizer.set_state_dict(optimizer_dict['model'])
self.hubert_optimizer.set_state_dict(optimizer_dict['hubert'])
# resotre lr_scheduler from *.pdlrs
scheduler_path = os.path.join(self.checkpoint_dir,
"{}".format(self.epoch)) + '.pdlrs'
if os.path.isfile(os.path.join(scheduler_path)):
scheduler_dict = paddle.load(scheduler_path)
if self.config.model_scheduler == 'newbobscheduler':
self.model_lr_scheduler.load(scheduler_dict['model'])
if self.config.hubert_scheduler == 'newbobscheduler':
self.hubert_lr_scheduler.load(scheduler_dict['hubert'])
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
scratch = False
else:
self.iteration = 0
self.epoch = 0
scratch = True
logger.info("Init from scratch!")
return scratch
def do_train(self):
"""The training process control by step."""
# !!!IMPORTANT!!!
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
# script_model = paddle.jit.to_static(self.model)
# script_model_path = str(self.checkpoint_dir / 'init')
# paddle.jit.save(script_model, script_model_path)
self.before_train()
if not self.use_streamdata:
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
try:
data_start_time = time.time()
for batch_index, batch in enumerate(self.train_loader):
dataload_time = time.time() - data_start_time
msg = "Train:"
observation = OrderedDict()
with ObsScope(observation):
report("Rank", dist.get_rank())
report("epoch", self.epoch)
report('step', self.iteration)
report("model_lr", self.model_optimizer.get_lr())
report("hubert_lr", self.hubert_optimizer.get_lr())
self.train_batch(batch_index, batch, msg)
self.after_train_batch()
report('iter', batch_index + 1)
if not self.use_streamdata:
report('total', len(self.train_loader))
report('reader_cost', dataload_time)
observation['batch_cost'] = observation[
'reader_cost'] + observation['step_cost']
observation['samples'] = observation['batch_size']
observation['ips,samples/s'] = observation[
'batch_size'] / observation['batch_cost']
for k, v in observation.items():
msg += f" {k.split(',')[0]}: "
msg += f"{v:>.8f}" if isinstance(v,
float) else f"{v}"
msg += f" {k.split(',')[1]}" if len(
k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1] # remove the last ","
if (batch_index + 1) % self.config.log_interval == 0:
logger.info(msg)
data_start_time = time.time()
except Exception as e:
logger.error(e)
raise e
with Timer("Eval Time Cost: {}"):
total_loss, num_seen_utts = self.valid()
if dist.get_world_size() > 1:
num_seen_utts = paddle.to_tensor(num_seen_utts)
dist.all_reduce(num_seen_utts)
total_loss = paddle.to_tensor(total_loss)
dist.all_reduce(total_loss)
cv_loss = total_loss / num_seen_utts
cv_loss = float(cv_loss)
else:
cv_loss = float(total_loss)
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/model_lr',
value=self.model_lr_scheduler(),
step=self.epoch)
self.visualizer.add_scalar(
tag='eval/hubert_lr',
value=self.hubert_lr_scheduler(),
step=self.epoch)
if self.config.model_scheduler == 'newbobscheduler':
self.model_lr_scheduler.step(cv_loss)
if self.config.hubert_scheduler == 'newbobscheduler':
if not self.config.freeze_hubert:
self.hubert_lr_scheduler.step(cv_loss)
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.avg_train_loss = 0.0
self.new_epoch()
def dataio_prepare(self, hparams):
"""This function prepares the datasets to be used in the brain class.
It also defines the data processing pipeline through user-defined functions."""
data_folder = hparams["data_folder"]
train_data = dataset.DynamicItemDataset.from_csv(
csv_path=hparams["train_data"],
replacements={"data_root": data_folder}, )
if hparams["sorting"] == "ascending":
# we sort training data to speed up training and get better results.
train_data = train_data.filtered_sorted(sort_key="duration")
# when sorting do not shuffle in dataloader ! otherwise is pointless
hparams["train_dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "descending":
train_data = train_data.filtered_sorted(
sort_key="duration", reverse=True)
# when sorting do not shuffle in dataloader ! otherwise is pointless
hparams["train_dataloader_opts"]["shuffle"] = False
elif hparams["sorting"] == "random":
pass
else:
raise NotImplementedError(
"sorting must be random, ascending or descending")
valid_data = dataset.DynamicItemDataset.from_csv(
csv_path=hparams["valid_data"],
replacements={"data_root": data_folder}, )
valid_data = valid_data.filtered_sorted(sort_key="duration")
test_data = dataset.DynamicItemDataset.from_csv(
csv_path=hparams["test_data"],
replacements={"data_root": data_folder}, )
test_data = test_data.filtered_sorted(sort_key="duration")
datasets = [train_data, valid_data, test_data]
# Defining tokenizer and loading it
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
self.tokenizer = tokenizer
# 2. Define audio pipeline:
@data_pipeline.takes("wav")
@data_pipeline.provides("sig")
def audio_pipeline(wav):
sig = dataio.read_audio(wav)
return sig
dataset.add_dynamic_item(datasets, audio_pipeline)
# 3. Define text pipeline:
@data_pipeline.takes("transcript")
@data_pipeline.provides("wrd", "tokens_list", "tokens")
def text_pipeline(wrd):
wrd = "".join(wrd.split(" "))
yield wrd
tokens_list = tokenizer(wrd)["input_ids"]
yield tokens_list
tokens = np.array(tokens_list, dtype="int64")
# tokens = paddle.to_tensor(tokens_list, dtype="int64")
yield tokens
dataset.add_dynamic_item(datasets, text_pipeline)
# 4. Set output:
dataset.set_output_keys(
datasets,
["id", "sig", "wrd", "tokens"], )
# 5. If Dynamic Batching is used, we instantiate the needed samplers.
train_batch_sampler = None
valid_batch_sampler = None
if hparams["dynamic_batching"]:
from sampler import DynamicBatchSampler # noqa
dynamic_hparams = hparams["dynamic_batch_sampler"]
num_buckets = dynamic_hparams["num_buckets"]
train_batch_sampler = DynamicBatchSampler(
train_data,
dynamic_hparams["max_batch_len"],
num_buckets=num_buckets,
length_func=lambda x: x["duration"],
shuffle=dynamic_hparams["shuffle_ex"],
batch_ordering=dynamic_hparams["batch_ordering"], )
valid_batch_sampler = DynamicBatchSampler(
valid_data,
dynamic_hparams["max_batch_len"],
num_buckets=num_buckets,
length_func=lambda x: x["duration"],
shuffle=dynamic_hparams["shuffle_ex"],
batch_ordering=dynamic_hparams["batch_ordering"], )
return (train_data, valid_data, test_data, tokenizer,
train_batch_sampler, valid_batch_sampler, )
def setup_dataloader(self):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
self.use_sb = config.get("use_sb_pipeline", False)
if self.use_sb:
hparams_file = config.sb_pipeline_conf
with open(hparams_file, 'r', encoding='utf8') as fin:
hparams = load_hyperpyyaml(fin, None)
(train_data, valid_data, test_data, tokenizer, train_bsampler,
valid_bsampler, ) = self.dataio_prepare(hparams)
train_dataloader_opts = hparams["train_dataloader_opts"]
valid_dataloader_opts = hparams["valid_dataloader_opts"]
if train_bsampler is not None:
train_dataloader_opts = {
"batch_sampler": train_bsampler,
"num_workers": hparams["num_workers"],
}
if valid_bsampler is not None:
valid_dataloader_opts = {"batch_sampler": valid_bsampler}
if self.train:
self.train_loader = make_dataloader(
train_data, stage='train', **train_dataloader_opts)
self.valid_loader = make_dataloader(
valid_data,
stage='val',
**valid_dataloader_opts, )
logger.info("Setup train/valid Dataloader!")
else:
self.test_loader = make_dataloader(
test_data, stage='test', **hparams["test_dataloader_opts"])
else:
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader(
'test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config
model_conf = config
with UpdateConfig(model_conf):
if self.use_sb:
model_conf.output_dim = self.tokenizer.vocab_size
else:
if self.train:
model_conf.input_dim = self.train_loader.feat_dim
model_conf.output_dim = self.train_loader.vocab_size
else:
model_conf.input_dim = self.test_loader.feat_dim
model_conf.output_dim = self.test_loader.vocab_size
model = HubertASR.from_config(model_conf)
model_dict = paddle.load(config.hubert_params_path)
model.set_state_dict(model_dict)
if self.parallel:
model = paddle.DataParallel(model, find_unused_parameters=True)
layer_tools.print_params(model, logger.info)
self.model = model
logger.info("Setup model!")
# setup speech augmentation for hubert
if hasattr(config, 'audio_augment') and self.train:
self.speech_augmentation = TimeDomainSpecAugment(
**config.audio_augment)
if not self.train:
return
train_config = config
model_optim_type = train_config.model_optim
model_optim_conf = train_config.model_optim_conf
logger.info("optim_model:{},{}", model_optim_type, model_optim_conf)
hubert_optim_type = train_config.hubert_optim
hubert_optim_conf = train_config.hubert_optim_conf
logger.info("optim_model:{},{}", hubert_optim_type, hubert_optim_conf)
model_scheduler_type = train_config.model_scheduler
model_scheduler_conf = train_config.model_scheduler_conf
hubert_scheduler_type = train_config.hubert_scheduler
hubert_scheduler_conf = train_config.hubert_scheduler_conf
model_scheduler_args = dict(
**{"learning_rate": model_optim_conf.lr,
"verbose": False}, **(dict(model_scheduler_conf)))
hubert_scheduler_args = dict(
**{"learning_rate": hubert_optim_conf.lr,
"verbose": False}, **(dict(hubert_scheduler_conf)))
model_lr_scheduler = LRSchedulerFactory.from_args(model_scheduler_type,
model_scheduler_args)
hubert_lr_scheduler = LRSchedulerFactory.from_args(
hubert_scheduler_type, hubert_scheduler_args)
def optimizer_args(
config,
optim_type,
optim_conf,
parameters,
lr_scheduler=None, ):
optim_arg = dict(optim_conf)
optim_arg.update({
"learning_rate":
lr_scheduler if lr_scheduler else optim_conf.lr,
"parameters":
parameters
})
return optim_arg
model_optimizer_args = optimizer_args(config, model_optim_type,
model_optim_conf, [{
'params':
model._layers.enc.parameters()
}, {
'params':
model._layers.ctc.parameters()
}] if self.parallel else [{
'params':
model.enc.parameters()
}, {
'params':
model.ctc.parameters()
}], model_lr_scheduler)
hubert_optimizer_args = optimizer_args(
config, hubert_optim_type, hubert_optim_conf,
model._layers.hubert.parameters() if self.parallel else
model.hubert.parameters(), hubert_lr_scheduler)
model_optimizer = OptimizerFactory.from_args(model_optim_type,
model_optimizer_args)
hubert_optimizer = OptimizerFactory.from_args(hubert_optim_type,
hubert_optimizer_args)
self.model_optimizer = model_optimizer
self.hubert_optimizer = hubert_optimizer
self.model_lr_scheduler = model_lr_scheduler
self.hubert_lr_scheduler = hubert_lr_scheduler
logger.info("Setup optimizer/lr_scheduler!")
class HubertASRTester(HubertASRTrainer):
def __init__(self, config, args):
super().__init__(config, args)
self.text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self.text_featurizer.vocab_list
def id2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self, id, audio, audio_len, texts, texts_len,
fout=None):
decode_cfg = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len)
result_transcripts, result_tokenids = self.model.decode(
audio,
text_feature=self.text_featurizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(
id, target_transcripts, result_transcripts, result_tokenids):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % (
decode_cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type,
num_frames=audio_len.sum().numpy().item(),
decode_time=decode_time)
def sb_compute_metrics(self, id, sig, wrd, tokens, fout=None):
decode_cfg = self.config.decode
errors_sum, len_refs, num_ins = 0.0, 0, 0
errors_func = error_rate.char_errors if decode_cfg.error_rate_type == 'cer' else error_rate.word_errors
error_rate_func = error_rate.cer if decode_cfg.error_rate_type == 'cer' else error_rate.wer
start_time = time.time()
target_transcripts = wrd
result_transcripts, result_tokenids = self.model.decode(
sig[0],
text_feature=self.tokenizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size,
sb_pipeline=True)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(
id, target_transcripts, result_transcripts, result_tokenids):
errors, len_ref = errors_func(target, result)
errors_sum += errors
len_refs += len_ref
num_ins += 1
if fout:
fout.write({
"utt": utt,
"refs": [target],
"hyps": [result],
"hyps_tokenid": [rec_tids],
})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")
logger.info("One example error rate [%s] = %f" % (
decode_cfg.error_rate_type, error_rate_func(target, result)))
return dict(
errors_sum=errors_sum,
len_refs=len_refs,
num_ins=num_ins, # num examples
error_rate=errors_sum / len_refs,
error_rate_type=decode_cfg.error_rate_type,
num_frames=sig[1].sum().numpy().item(),
decode_time=decode_time)
@mp_tools.rank_zero_only
@paddle.no_grad()
def test(self):
logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
self.model.eval()
error_rate_type = None
errors_sum, len_refs, num_ins = 0.0, 0, 0
num_frames = 0.0
num_time = 0.0
# Initialized the decoder in model
decode_cfg = self.config.decode
vocab_list = self.vocab_list
decode_batch_size = decode_cfg.decode_batch_size
with jsonlines.open(self.args.result_file, 'w') as fout:
for i, batch in enumerate(self.test_loader):
if self.use_sb:
metrics = self.sb_compute_metrics(**batch, fout=fout)
else:
metrics = self.compute_metrics(*batch, fout=fout)
num_frames += metrics['num_frames']
num_time += metrics["decode_time"]
errors_sum += metrics['errors_sum']
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
error_rate_type = metrics['error_rate_type']
rtf = num_time / (num_frames)
logger.info(
"RTF: %f, Error rate [%s] (%d/?) = %f" %
(rtf, error_rate_type, num_ins, errors_sum / len_refs))
# logging
msg = "Test: "
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "Final error rate [%s] (%d/%d) = %f" % (
error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg)
err_meta_path = os.path.splitext(self.args.result_file)[0] + '.err'
err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w', encoding='utf8') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
error_rate_type:
errors_sum / len_refs,
"dataset_hour": (num_frames) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"err_sum":
errors_sum,
"ref_len":
len_refs,
"decode_method":
self.config.decode.decoding_method,
})
f.write(data + '\n')

@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .hubert_ASR import HubertASR
from .hubert_ASR import HubertBase
__all__ = ["HubertASR", "HubertBase"]

@ -0,0 +1,368 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HubertASR model."""
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import is_dataclass
from typing import Dict
from typing import List
from typing import Tuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertConfig
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertModel
from paddlespeech.s2t.models.hubert.modules.hubert_model import HubertPretrainingConfig
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import log_add
logger = Log(__name__).getlog()
class HubertASR(nn.Layer):
def __init__(self, config: dict):
super().__init__()
init_type = config.get("init_type", None)
with DefaultInitializerContext(init_type):
self.config = config
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
dict(self.config.task_cfg))
model_cfg = self.merge_with_parent(HubertConfig,
dict(self.config.model_cfg))
hubert = HubertModel(model_cfg, task_cfg, [None])
self.normalize_wav = config.normalize_wav
self.output_norm = config.output_norm
if hasattr(config, 'spec_augment'):
self.spec_augment = SpecAugment(**config.spec_augment)
if config.freeze_hubert:
hubert.eval()
for parm in hubert.parameters():
parm.trainable = False
self.hubert = hubert
self.enc = VanillaNN(**config.enc)
self.ctc = CTC(**config.ctc,
odim=config.output_dim,
batch_average=False,
reduction='mean')
def merge_with_parent(self, dc: dataclass, cfg: dict):
assert is_dataclass(dc)
assert type(cfg) == dict
cfg = deepcopy(cfg)
def fix_cfg(cfg):
target_keys = set(dc.__dataclass_fields__.keys())
for k in list(cfg.keys()):
if k not in target_keys:
del cfg[k]
fix_cfg(cfg)
assert len(cfg) > 0
return dc(**cfg)
def forward(self, wav, wavs_lens_rate, target, target_lens):
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape)
# Extract wav2vec output
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape)
if self.training and hasattr(self.config, 'spec_augment'):
feats = self.spec_augment(out)
else:
feats = out
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss
@paddle.no_grad()
def decode(self,
feats: paddle.Tensor,
text_feature: Dict[str, int],
decoding_method: str,
beam_size: int,
tokenizer: str=None,
sb_pipeline=False):
batch_size = feats.shape[0]
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error(
f"decoding mode {decoding_method} must be running with batch_size == 1"
)
logger.error(f"current batch_size is {batch_size}")
if decoding_method == 'ctc_greedy_search':
if tokenizer is None and sb_pipeline is False:
hyps = self.ctc_greedy_search(feats)
res = [text_feature.defeaturize(hyp) for hyp in hyps]
res_tokenids = [hyp for hyp in hyps]
else:
if sb_pipeline is True:
hyps = self.ctc_greedy_search(feats.unsqueeze(-1))
else:
hyps = self.ctc_greedy_search(feats)
res = []
res_tokenids = []
for sequence in hyps:
# Decode token terms to words
predicted_tokens = text_feature.convert_ids_to_tokens(
sequence)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
# ctc_prefix_beam_search and attention_rescoring only return one
# result in List[int], change it to List[List[int]] for compatible
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
if tokenizer is None and sb_pipeline is False:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
if sb_pipeline is True:
hyp = self.ctc_prefix_beam_search(
feats.unsqueeze(-1), beam_size)
else:
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = []
res_tokenids = []
predicted_tokens = text_feature.convert_ids_to_tokens(hyp)
tmp_res = []
tmp_res_tokenids = []
for c in predicted_tokens:
if c == "[CLS]":
continue
elif c == "[SEP]" or c == "[PAD]":
break
else:
tmp_res.append(c)
tmp_res_tokenids.append(text_feature.vocab[c])
res.append(''.join(tmp_res))
res_tokenids.append(tmp_res_tokenids)
else:
raise ValueError(
f"wav2vec2 not support decoding method: {decoding_method}")
return res, res_tokenids
@classmethod
def from_config(cls, config):
model = cls(config)
return model
def ctc_greedy_search(self, wav) -> List[List[int]]:
""" Apply CTC greedy search
Args:
speech (paddle.Tensor): (batch, max_len)
speech_length (paddle.Tensor): (batch, )
Returns:
List[List[int]]: best path result
"""
batch_size = wav.shape[0]
wav = wav[:, :, 0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
x_lens = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (B, maxlen, vocab_size)
topk_prob, topk_index = ctc_probs.topk(1, axis=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, x_lens) # (B, maxlen)
hyps = [hyp.tolist() for hyp in topk_index]
hyps = [remove_duplicates_and_blank(hyp) for hyp in hyps]
return hyps
def _ctc_prefix_beam_search(
self,
wav,
beam_size,
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[Tuple[int, float]]: nbest results, (N,1), (text, likelihood)
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
wav = wav[:, :, 0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
out = self.hubert.extract_features(wav)[0]
# We normalize the output if required
if self.output_norm:
out = F.layer_norm(out, out.shape[1:])
feats = out
x = self.enc(feats)
maxlen = x.shape[1]
ctc_probs = self.ctc.log_softmax(x) # (1, maxlen, vocab_size)
ctc_probs = ctc_probs.squeeze(0)
# cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
# blank_ending_score and none_blank_ending_score in ln domain
cur_hyps = [(tuple(), (0.0, -float('inf')))]
# 2. CTC beam search step by step
for t in range(0, maxlen):
logp = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
# 2.1 First beam prune: select topk best
top_k_logp, top_k_index = logp.topk(beam_size) # (beam_size,)
for s in top_k_index:
s = s.item()
ps = logp[s].item()
for prefix, (pb, pnb) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == blank_id: # blank
n_pb, n_pnb = next_hyps[prefix]
n_pb = log_add([n_pb, pb + ps, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
elif s == last:
# Update *ss -> *s;
n_pb, n_pnb = next_hyps[prefix]
n_pnb = log_add([n_pnb, pnb + ps])
next_hyps[prefix] = (n_pb, n_pnb)
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb = next_hyps[n_prefix]
n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
next_hyps[n_prefix] = (n_pb, n_pnb)
# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(),
key=lambda x: log_add(list(x[1])),
reverse=True)
cur_hyps = next_hyps[:beam_size]
hyps = [(y[0], log_add([y[1][0], y[1][1]])) for y in cur_hyps]
return hyps
def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
""" Apply CTC prefix beam search
Args:
speech (paddle.Tensor): (batch, max_len, feat_dim)
speech_length (paddle.Tensor): (batch, )
beam_size (int): beam size for beam search
decoding_chunk_size (int): decoding chunk for dynamic chunk
trained model.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
0: used for training, it's prohibited here
simulate_streaming (bool): whether do encoder forward in a
streaming fashion
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps = self._ctc_prefix_beam_search(wav, beam_size)
return hyps[0][0]
class HubertBase(nn.Layer):
"""Hubert model"""
def __init__(self, config: dict):
super().__init__()
self.config = config
task_cfg = self.merge_with_parent(HubertPretrainingConfig,
dict(self.config.task_cfg))
model_cfg = self.merge_with_parent(HubertConfig,
dict(self.config.model_cfg))
hubert = HubertModel(model_cfg, task_cfg, [None])
self.hubert = hubert
@classmethod
def from_config(cls, configs: dict):
"""init model.
Args:
configs (dict): config dict.
Raises:
ValueError: raise when using not support encoder type.
Returns:
nn.Layer: HubertBase
"""
model = cls(configs)
return model
def merge_with_parent(self, dc: dataclass, cfg: dict):
assert is_dataclass(dc)
assert type(cfg) == dict
cfg = deepcopy(cfg)
def fix_cfg(cfg):
target_keys = set(dc.__dataclass_fields__.keys())
for k in list(cfg.keys()):
if k not in target_keys:
del cfg[k]
fix_cfg(cfg)
assert len(cfg) > 0
return dc(**cfg)
def forward(self, wav):
out = self.hubert.extract_features(wav)
return out

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,586 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Paddle Hubert model."""
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn as nn
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ChoiceEnum
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import compute_mask_indices
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import ConvFeatureExtractionModel
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import EXTRACTOR_MODE_CHOICES
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import get_available_activation_fns
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GLU
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import GradMultiply
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import LAYER_TYPE_CHOICES
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import MASKING_DISTRIBUTION_CHOICES
from paddlespeech.s2t.models.wav2vec2.modules.wav2vec2_model import TransformerEncoder
from paddlespeech.s2t.modules.align import LayerNorm
from paddlespeech.s2t.modules.align import Linear
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
@dataclass
class HubertPretrainingConfig:
label_rate: float = field(
default=-1.0,
metadata={"help": "label frame rate. -1.0 for sequence label"}, )
sample_rate: int = field(
default=16_000,
metadata={
"help":
"target sample rate. audio files will be up/down "
"sampled to this rate"
}, )
normalize: bool = field(
default=False,
metadata={
"help": "if set, normalizes input to have 0 mean and unit variance"
}, )
enable_padding: bool = field(
default=False,
metadata={"help": "pad shorter samples instead of cropping"}, )
max_keep_size: Optional[int] = field(
default=None,
metadata={"help": "exclude sample longer than this"}, )
max_sample_size: Optional[int] = field(
default=None,
metadata={"help": "max sample size to crop to for batching"}, )
min_sample_size: Optional[int] = field(
default=None,
metadata={"help": "min sample size to crop to for batching"}, )
random_crop: Optional[bool] = field(
default=True,
metadata={"help": "always crop from the beginning if false"}, )
pad_audio: Optional[bool] = field(
default=False,
metadata={"help": "pad audio to the longest one in the batch if true"},
)
@dataclass
class HubertConfig:
label_rate: float
extractor_mode: EXTRACTOR_MODE_CHOICES = field(
default="default",
metadata={
"help":
"mode for feature extractor. default has a single group "
"norm with d groups in the first conv block, whereas layer_norm "
"has layer norms in every block (meant to use with normalize=True)"
}, )
encoder_layers: int = field(
default=12, metadata={"help": "num encoder layers in the transformer"})
encoder_embed_dim: int = field(
default=768, metadata={"help": "encoder embedding dimension"})
encoder_ffn_embed_dim: int = field(
default=3072, metadata={"help": "encoder embedding dimension for FFN"})
encoder_attention_heads: int = field(
default=12, metadata={"help": "num encoder attention heads"})
activation_fn: ChoiceEnum(get_available_activation_fns()) = field(
default="gelu", metadata={"help": "activation function to use"})
layer_type: LAYER_TYPE_CHOICES = field(
default="transformer", metadata={"help": "layer type in encoder"})
# dropouts
dropout: float = field(
default=0.1,
metadata={"help": "dropout probability for the transformer"}, )
attention_dropout: float = field(
default=0.1,
metadata={"help": "dropout probability for attention weights"}, )
activation_dropout: float = field(
default=0.0,
metadata={"help": "dropout probability after activation in FFN"}, )
encoder_layerdrop: float = field(
default=0.0,
metadata={"help": "probability of dropping a tarnsformer layer"}, )
dropout_input: float = field(
default=0.0,
metadata={"help": "dropout to apply to the input (after feat extr)"}, )
dropout_features: float = field(
default=0.0,
metadata={"help": "dropout to apply to the features (after feat extr)"},
)
final_dim: int = field(
default=0,
metadata={
"help":
"project final representations and targets to this many "
"dimensions. set to encoder_embed_dim is <= 0"
}, )
untie_final_proj: bool = field(
default=False,
metadata={"help": "use separate projection for each target"}, )
layer_norm_first: bool = field(
default=False,
metadata={"help": "apply layernorm first in the transformer"}, )
conv_feature_layers: str = field(
default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2",
metadata={
"help":
"string describing convolutional feature extraction "
"layers in form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
}, )
conv_bias: bool = field(
default=False, metadata={"help": "include bias in conv encoder"})
logit_temp: float = field(
default=0.1, metadata={"help": "temperature to divide logits by"})
target_glu: bool = field(
default=False, metadata={"help": "adds projection + glu to targets"})
feature_grad_mult: float = field(
default=1.0,
metadata={"help": "multiply feature extractor var grads by this"}, )
# masking
mask_length: int = field(default=10, metadata={"help": "mask length"})
mask_prob: float = field(
default=0.65,
metadata={"help": "probability of replacing a token with mask"}, )
mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static", metadata={"help": "how to choose mask length"})
mask_other: float = field(
default=0,
metadata={
"help":
"secondary mask argument "
"(used for more complex distributions), "
"see help in compute_mask_indicesh"
}, )
no_mask_overlap: bool = field(
default=False, metadata={"help": "whether to allow masks to overlap"})
mask_min_space: int = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
# channel masking
mask_channel_length: int = field(
default=10,
metadata={"help": "length of the mask for features (channels)"}, )
mask_channel_prob: float = field(
default=0.0,
metadata={"help": "probability of replacing a feature with 0"}, )
mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
default="static",
metadata={"help": "how to choose mask length for channel masking"}, )
mask_channel_other: float = field(
default=0,
metadata={
"help":
"secondary mask argument "
"(used for more complex distributions), "
"see help in compute_mask_indicesh"
}, )
no_mask_channel_overlap: bool = field(
default=False,
metadata={"help": "whether to allow channel masks to overlap"}, )
mask_channel_min_space: int = field(
default=1,
metadata={"help": "min space between spans (if no overlap is enabled)"},
)
# positional embeddings
conv_pos: int = field(
default=128,
metadata={
"help": "number of filters for convolutional positional embeddings"
}, )
conv_pos_groups: int = field(
default=16,
metadata={
"help": "number of groups for convolutional positional embedding"
}, )
latent_temp: Tuple[float, float, float] = field(
default=(2, 0.5, 0.999995),
metadata={"help": "legacy (to be removed)"}, )
# loss computation
skip_masked: bool = field(
default=False,
metadata={"help": "skip computing losses over masked frames"}, )
skip_nomask: bool = field(
default=False,
metadata={"help": "skip computing losses over unmasked frames"}, )
checkpoint_activations: bool = field(
default=False,
metadata={
"help": "recompute activations and save memory for extra compute"
}, )
# FP16 optimization
required_seq_len_multiple: int = field(
default=2,
metadata={
"help":
"pad the input to encoder such that the sequence length is divisible by multiple"
}, )
# Conformer
depthwise_conv_kernel_size: int = field(
default=31,
metadata={
"help":
"depthwise-conv-kernel-size for convolution in conformer layer"
}, )
attn_type: str = field(
default="",
metadata={"help": "if espnet use ESPNET MHA"}, )
pos_enc_type: str = field(
default="abs",
metadata={"help": "Positional encoding type to use in conformer"}, )
fp16: bool = field(
default=False, metadata={"help": "If fp16 is being used"})
class HubertModel(nn.Layer):
def __init__(
self,
cfg: HubertConfig,
task_cfg: HubertPretrainingConfig,
dictionaries: List[Any], ) -> None:
super().__init__()
logger.info(f"HubertModel Config: {cfg}")
feature_enc_layers = eval(cfg.conv_feature_layers) # noqa
self.embed = feature_enc_layers[-1][0]
self.feature_extractor = ConvFeatureExtractionModel(
conv_layers=feature_enc_layers,
dropout=0.0,
mode=cfg.extractor_mode,
conv_bias=cfg.conv_bias, )
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate
self.post_extract_proj = (Linear(self.embed, cfg.encoder_embed_dim) if
self.embed != cfg.encoder_embed_dim else None)
self.mask_prob = cfg.mask_prob
self.mask_selection = cfg.mask_selection
self.mask_other = cfg.mask_other
self.mask_length = cfg.mask_length
self.no_mask_overlap = cfg.no_mask_overlap
self.mask_min_space = cfg.mask_min_space
self.mask_channel_prob = cfg.mask_channel_prob
self.mask_channel_selection = cfg.mask_channel_selection
self.mask_channel_other = cfg.mask_channel_other
self.mask_channel_length = cfg.mask_channel_length
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
self.mask_channel_min_space = cfg.mask_channel_min_space
self.dropout_input = nn.Dropout(cfg.dropout_input)
self.dropout_features = nn.Dropout(cfg.dropout_features)
self.feature_grad_mult = cfg.feature_grad_mult
self.logit_temp = cfg.logit_temp
self.skip_masked = cfg.skip_masked
self.skip_nomask = cfg.skip_nomask
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
self.mask_emb = paddle.create_parameter(
shape=[cfg.encoder_embed_dim],
dtype='float32',
default_initializer=paddle.nn.initializer.Uniform(low=0), )
self.encoder = TransformerEncoder(cfg)
self.layer_norm = LayerNorm(self.embed)
self.target_glu = None
if cfg.target_glu:
self.target_glu = nn.Sequential(
Linear(final_dim, final_dim * 2), GLU())
self.untie_final_proj = cfg.untie_final_proj
if self.untie_final_proj:
self.final_proj = Linear(cfg.encoder_embed_dim,
final_dim * len(dictionaries))
else:
self.final_proj = Linear(cfg.encoder_embed_dim, final_dim)
# modules below are not needed during fine-tuning
if any([d is None for d in dictionaries]):
logger.info(
"cannot find dictionary. assume will be used for fine-tuning")
else:
self.num_classes = [len(d) for d in dictionaries]
self.label_embs_concat = paddle.create_parameter(
shape=[sum(self.num_classes), final_dim],
dtype='float32',
default_initializer=paddle.nn.initializer.Uniform(low=0), )
@classmethod
def build_model(cls, cfg: HubertConfig, task):
"""Build a new model instance."""
model = HubertModel(cfg, task.cfg, task.dictionaries)
return model
def apply_mask(self, x, padding_mask, target_list):
B, T, C = x.shape
if self.mask_prob > 0:
mask_indices = compute_mask_indices(
(B, T),
padding_mask,
self.mask_prob,
self.mask_length,
self.mask_selection,
self.mask_other,
min_masks=2,
no_overlap=self.no_mask_overlap,
min_space=self.mask_min_space, )
mask_indices = paddle.to_tensor(
mask_indices, dtype='int64', place=x.place)
x[mask_indices] = self.mask_emb
else:
mask_indices = None
if self.mask_channel_prob > 0:
mask_channel_indices = compute_mask_indices(
(B, C),
None,
self.mask_channel_prob,
self.mask_channel_length,
self.mask_channel_selection,
self.mask_channel_other,
no_overlap=self.no_mask_channel_overlap,
min_space=self.mask_channel_min_space, )
mask_channel_indices = (paddle.to_tensor(
mask_channel_indices, dtype='int64', place=x.place).unsqueeze(1)
.expand([-1, T, -1]))
x[mask_channel_indices] = 0
return x, mask_indices
def compute_nce(self, x, pos, negs):
neg_is_pos = (pos == negs).all(-1)
pos = pos.unsqueeze(0)
targets = paddle.concat([pos, negs], axis=0)
logits = paddle.nn.functional.cosine_similarity(
x.astype('float32'), targets.astype('float32'), axis=-1)
logits /= self.logit_temp
if paddle.any(neg_is_pos):
logits[1:][neg_is_pos] = float("-inf")
logits = logits.transpose([1, 0]) # (num_x, num_cls+1)
return logits
def forward_features(self, source: paddle.Tensor) -> paddle.Tensor:
if self.feature_grad_mult > 0:
features = self.feature_extractor(source)
if self.feature_grad_mult != 1.0:
features = GradMultiply.apply(features, self.feature_grad_mult)
else:
with paddle.no_grad():
features = self.feature_extractor(source)
return features
def forward_targets(
self,
features: paddle.Tensor,
target_list: List[paddle.Tensor],
) -> Tuple[paddle.Tensor, paddle.Tensor]:
# Trim features to ensure labels exist and then get aligned labels
feat_tsz = features.shape[2]
targ_tsz = min([t.shape[1] for t in target_list])
if self.feat2tar_ratio * feat_tsz > targ_tsz:
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
features = features[:, :, :feat_tsz]
target_inds = paddle.arange(feat_tsz).astype(
'float32') * self.feat2tar_ratio
target_list = [t[:, target_inds.astype('int64')] for t in target_list]
return features, target_list
def forward_padding_mask(
self,
features: paddle.Tensor,
padding_mask: paddle.Tensor, ) -> paddle.Tensor:
extra = padding_mask.shape[1] % features.shape[1]
if extra > 0:
padding_mask = padding_mask[:, :-extra]
padding_mask = paddle.reshape(
padding_mask, [padding_mask.shape[0], features.shape[1], -1])
padding_mask = paddle.all(padding_mask, axis=-1)
return padding_mask
def forward(
self,
source: paddle.Tensor,
target_list: Optional[List[paddle.Tensor]]=None,
padding_mask: Optional[paddle.Tensor]=None,
mask: bool=True,
features_only: bool=False,
output_layer: Optional[int]=None, ) -> Dict[str, paddle.Tensor]:
"""output layer is 1-based"""
features = self.forward_features(source)
if target_list is not None:
features, target_list = self.forward_targets(features, target_list)
features_pen = features.pow(2).mean()
features = features.transpose([0, 2, 1])
features = self.layer_norm(features)
unmasked_features = features.clone()
if padding_mask is not None:
padding_mask = self.forward_padding_mask(features, padding_mask)
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)
features = self.dropout_input(features)
unmasked_features = self.dropout_features(unmasked_features)
if mask:
x, mask_indices = self.apply_mask(features, padding_mask,
target_list)
else:
x = features
mask_indices = None
# feature: (B, T, D), float
# target: (B, T), long
# x: (B, T, D), float
# padding_mask: (B, T), bool
# mask_indices: (B, T), bool
x, _ = self.encoder(
x,
padding_mask=padding_mask,
layer=None if output_layer is None else output_layer - 1, )
if features_only:
return {"x": x, "padding_mask": padding_mask, "features": features}
def compute_pred(self, proj_x, target, label_embs):
# compute logits for the i-th label set
y = paddle.index_select(
label_embs, index=target.astype('int64'), axis=0)
negs = paddle.expand(
label_embs.unsqueeze(1),
[label_embs.shape[0], proj_x.shape[0], label_embs.shape[-1]])
if self.target_glu:
y = self.target_glu(y)
negs = self.target_glu(negs)
# proj_x: (S, D)
# y: (S, D)
# negs: (Neg, S, D)
return self.compute_nce(proj_x, y, negs)
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
if not self.skip_masked:
masked_indices = paddle.logical_and(~padding_mask, mask_indices)
proj_x_m = self.final_proj(x[masked_indices])
if self.untie_final_proj:
proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
else:
proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
logit_m_list = [
compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
for i, (proj_x_m, t
) in enumerate(zip(proj_x_m_list, target_list))
]
else:
logit_m_list = [None for _ in target_list]
if not self.skip_nomask:
nomask_indices = paddle.logical_and(~padding_mask, ~mask_indices)
proj_x_u = self.final_proj(x[nomask_indices])
if self.untie_final_proj:
proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
else:
proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
logit_u_list = [
compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
for i, (proj_x_u, t
) in enumerate(zip(proj_x_u_list, target_list))
]
else:
logit_u_list = [None for _ in target_list]
result = {
"logit_m_list": logit_m_list,
"logit_u_list": logit_u_list,
"padding_mask": padding_mask,
"features_pen": features_pen,
}
return result
def extract_features(
self,
source: paddle.Tensor,
padding_mask: Optional[paddle.Tensor]=None,
mask: bool=False,
ret_conv: bool=False,
output_layer: Optional[int]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
res = self.forward(
source,
padding_mask=padding_mask,
mask=mask,
features_only=True,
output_layer=output_layer, )
feature = res["features"] if ret_conv else res["x"]
return feature, res["padding_mask"]
def get_logits(self, net_output, is_masked=True):
if is_masked:
logits_list = net_output["logit_m_list"]
else:
logits_list = net_output["logit_u_list"]
logits_list = [
paddle.cast(x, 'float32') for x in logits_list if x is not None
]
return logits_list
def get_targets(self, net_output, is_masked=True):
logits_list = self.get_logits(net_output, is_masked)
targets_list = [
paddle.zeros_like(x, dtype='int64') for x in logits_list
]
return targets_list
def get_extra_losses(self, net_output):
extra_losses = []
names = []
if "features_pen" in net_output:
extra_losses.append(net_output["features_pen"])
names.append("features_pen")
return extra_losses, names
def remove_pretraining_modules(self):
self.target_glu = None
self.final_proj = None

File diff suppressed because it is too large Load Diff

@ -27,8 +27,11 @@ from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import Spec
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import log_add
logger = Log(__name__).getlog()
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):

Loading…
Cancel
Save