commit
4a11257dcb
@@ -1,44 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from dtaidistance import dtw_ndim

__all__ = [
    'dtw_distance',
]


def dtw_distance(xs: np.ndarray, ys: np.ndarray) -> float:
    """Dynamic Time Warping.

    This function keeps a compact matrix, not the full warping paths matrix.
    It uses dynamic programming to compute:

    .. code-block:: text

        wps[i, j] = (s1[i]-s2[j])**2 + min(
                        wps[i-1, j  ] + penalty,  // vertical / insertion / expansion
                        wps[i  , j-1] + penalty,  // horizontal / deletion / compression
                        wps[i-1, j-1])            // diagonal / match

        dtw = sqrt(wps[-1, -1])

    Args:
        xs (np.ndarray): ref sequence, [T,D]
        ys (np.ndarray): hyp sequence, [T,D]

    Returns:
        float: dtw distance
    """
    return dtw_ndim.distance(xs, ys)
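

# A minimal usage sketch (hypothetical shapes; requires the dtaidistance package):
#   xs = np.random.randn(100, 80)  # reference sequence, [T1, D]
#   ys = np.random.randn(120, 80)  # hypothesis sequence, [T2, D]
#   d = dtw_distance(xs, ys)       # scalar DTW distance between the two sequences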
@@ -0,0 +1,146 @@
# VITS with CSMSC
This example contains code used to train a [VITS](https://arxiv.org/abs/2106.06103) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).

## Dataset
### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source).

### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for VITS; the durations from MFA are not needed here.
You can download [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.

## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
Run the command below to
1. **source the path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
    - synthesize waveform from `metadata.jsonl`.
    - synthesize waveform from a text file.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
    ├── feats_stats.npy
    ├── norm
    └── raw
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains the wave and linear spectrogram of each utterance, while the norm folder contains the normalized ones. The statistics used to normalize features are computed from the training set and stored in `dump/train/feats_stats.npy`.

Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, feats (the path of the linear spectrogram features), feats_lengths, the path of the raw waves, speaker, and the id of each utterance.
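For illustration, one record might look like the following (the exact field layout and all values here are hypothetical):
```text
{"utt_id": "009901", "phones": ["sil", "b", "ao3", "..."], "text_lengths": 12, "feats": "dump/train/raw/009901_feats.npy", "feats_lengths": 420, "wave": "dump/train/raw/009901_wave.npy", "speaker": "baker", "spk_id": 0}
```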

### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--phones-dict PHONES_DICT]

Train a VITS model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       config file to overwrite default config.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --phones-dict PHONES_DICT
                        phone vocabulary file.
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, the cpu is used.
5. `--phones-dict` is the path of the phone vocabulary file.
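Since `./local/train.sh` forwards these flags to `${BIN_DIR}/train.py`, a direct invocation might look like this (the paths are illustrative):
```bash
python3 ${BIN_DIR}/train.py \
    --config=conf/default.yaml \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --output-dir=exp/default \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
```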

### Synthesizing

`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.

```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--ckpt CKPT]
                     [--phones_dict PHONES_DICT] [--ngpu NGPU]
                     [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]

Synthesize with VITS

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Config of VITS.
  --ckpt CKPT           Checkpoint file of VITS.
  --phones_dict PHONES_DICT
                        phone vocabulary file.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --test_metadata TEST_METADATA
                        test metadata.
  --output_dir OUTPUT_DIR
                        output dir.
```
`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveform from a text file.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize_e2e.py [-h] [--config CONFIG] [--ckpt CKPT]
                         [--phones_dict PHONES_DICT] [--lang LANG]
                         [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
                         [--text TEXT] [--output_dir OUTPUT_DIR]

Synthesize with VITS

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       Config of VITS.
  --ckpt CKPT           Checkpoint file of VITS.
  --phones_dict PHONES_DICT
                        phone vocabulary file.
  --lang LANG           Choose model language. zh or en
  --inference_dir INFERENCE_DIR
                        dir to save inference models
  --ngpu NGPU           if ngpu == 0, use cpu.
  --text TEXT           text to synthesize, a 'utt_id sentence' pair per line.
  --output_dir OUTPUT_DIR
                        output dir.
```
1. `--config`, `--ckpt`, and `--phones_dict` are arguments for the acoustic model, which correspond to the 3 files in the VITS pretrained model.
2. `--lang` is the model language, which can be `zh` or `en`.
3. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
4. `--text` is the text file, which contains sentences to synthesize, one 'utt_id sentence' pair per line (see the example below).
5. `--output_dir` is the directory to save synthesized audio files.
6. `--ngpu` is the number of gpus to use; if ngpu == 0, the cpu is used.
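For illustration, a sentence file passed to `--text` might look like this (the utterance ids and sentences are made up):
```text
001 欢迎使用语音合成系统。
002 凡是先进的技术都值得我们学习。
```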

## Pretrained Model
@@ -1 +0,0 @@
tmp
@@ -1,19 +0,0 @@
# 1xt2x

Convert the DeepSpeech 1.8 released models to 2.x.

## Model source directory
* Deepspeech2x

## Experiment directory
* aishell
* librispeech
* baidu_en8k

# The released models

Acoustic Model | Training Data | Hours of Speech | Token-based | CER | WER
:-------------:| :------------:| :---------------: | :---------: | :---: | :----:
Ds2 Offline Aishell 1xt2x model | Aishell Dataset | 151 h | Char-based | 0.080447 |
Ds2 Offline Librispeech 1xt2x model | Librispeech Dataset | 960 h | Word-based | | 0.068548
Ds2 Offline Baidu en8k 1xt2x model | Baidu Internal English Dataset | 8628 h | Word-based | | 0.054112
@@ -1,5 +0,0 @@
exp
data
*log
tmp
nohup*
@@ -1 +0,0 @@
[]
@@ -1,65 +0,0 @@
# https://yaml.org/type/float.html
###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
min_input_len: 0.0
max_input_len: 27.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf

###########################################
# Dataloader                              #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2

############################################
# Network Architecture                     #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 4333

###########################################
# Training                                #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
  kbest_n: 50
  latest_n: 5
@@ -1,10 +0,0 @@
decode_batch_size: 32
error_rate_type: cer
decoding_method: ctc_beam_search
lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6
beta: 5.0
beam_size: 300
cutoff_prob: 0.99
cutoff_top_n: 40
num_proc_bsearch: 8
@@ -1,69 +0,0 @@
#!/bin/bash
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100

source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}

bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
    exit 1
fi

cd ${ckpt_dir}
tar xzvf aishell_model_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 ${TARGET_DIR}/aishell/aishell.py \
        --manifest_prefix="data/manifest" \
        --target_dir="${TARGET_DIR}/aishell"

    if [ $? -ne 0 ]; then
        echo "Prepare Aishell failed. Terminated."
        exit 1
    fi

    for dataset in train dev test; do
        mv data/manifest.${dataset} data/manifest.${dataset}.raw
    done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    for dataset in train dev test; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
            --cmvn_path "data/mean_std.npz" \
            --unit_type "char" \
            --vocab_path="data/vocab.txt" \
            --manifest_path="data/manifest.${dataset}.raw" \
            --output_path="data/manifest.${dataset}"

        if [ $? -ne 0 ]; then
            echo "Format manifest failed. Terminated."
            exit 1
        fi
    } &
    done
    wait
fi

echo "Aishell data preparation done."
exit 0
@@ -1,23 +0,0 @@
#!/bin/bash

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm


echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
    echo "Failed to download the language model!"
    exit 1
else
    echo "Downloaded the language model successfully."
fi


exit 0
@@ -1,25 +0,0 @@
#! /usr/bin/env bash

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_v1.8_to_v2.x.tar.gz'
MD5=87e7577d4bea737dbf3e8daab37aa808
TARGET=${ckpt_dir}/aishell_model_v1.8_to_v2.x.tar.gz


echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download Aishell model!"
    exit 1
fi


exit 0
@@ -1,36 +0,0 @@
#!/bin/bash

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi

python3 -u ${BIN_DIR}/test.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
    --decode_cfg ${decode_config_path} \
    --result_file ${ckpt_prefix}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@@ -1,17 +0,0 @@
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
@@ -1,29 +0,0 @@
#!/bin/bash
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=2

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=aishell_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p exp/${ckpt}/checkpoints
    bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
fi
@@ -1,5 +0,0 @@
exp
data
*log
tmp
nohup*
@@ -1 +0,0 @@
[]
@@ -1,64 +0,0 @@
# https://yaml.org/type/float.html
###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: .inf # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf

###########################################
# Dataloader                              #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2

############################################
# Network Architecture                     #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 1024
use_gru: True
share_rnn_weights: False
blank_id: 28

###########################################
# Training                                #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
  kbest_n: 50
  latest_n: 5
@@ -1,10 +0,0 @@
decode_batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 1.4
beta: 0.35
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
@@ -1,85 +0,0 @@
#!/bin/bash
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100
unit_type=char

source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}


bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
    exit 1
fi

cd ${ckpt_dir}
tar xzvf baidu_en8k_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 ${TARGET_DIR}/librispeech/librispeech.py \
        --manifest_prefix="data/manifest" \
        --target_dir="${TARGET_DIR}/librispeech" \
        --full_download="True"

    if [ $? -ne 0 ]; then
        echo "Prepare LibriSpeech failed. Terminated."
        exit 1
    fi

    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
        mv data/manifest.${set} data/manifest.${set}.raw
    done

    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
    for set in train-clean-100 train-clean-360 train-other-500; do
        cat data/manifest.${set}.raw >> data/manifest.train.raw
    done

    for set in dev-clean dev-other; do
        cat data/manifest.${set}.raw >> data/manifest.dev.raw
    done

    for set in test-clean test-other; do
        cat data/manifest.${set}.raw >> data/manifest.test.raw
    done
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    for set in train dev test dev-clean dev-other test-clean test-other; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
            --cmvn_path "data/mean_std.npz" \
            --unit_type ${unit_type} \
            --vocab_path="data/vocab.txt" \
            --manifest_path="data/manifest.${set}.raw" \
            --output_path="data/manifest.${set}"

        if [ $? -ne 0 ]; then
            echo "Format manifest.${set} failed. Terminated."
            exit 1
        fi
    } &
    done
    wait
fi

echo "LibriSpeech data preparation done."
exit 0
@@ -1,22 +0,0 @@
#!/bin/bash

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm

echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
    echo "Failed to download the language model!"
    exit 1
else
    echo "Downloaded the language model successfully."
fi


exit 0
@@ -1,25 +0,0 @@
#! /usr/bin/env bash
if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1


. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/baidu_en8k_v1.8_to_v2.x.tar.gz'
MD5=c1676be8505cee436e6f312823e9008c
TARGET=${ckpt_dir}/baidu_en8k_v1.8_to_v2.x.tar.gz


echo "Download BaiduEn8k model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download BaiduEn8k model!"
    exit 1
fi


exit 0
@@ -1,36 +0,0 @@
#!/bin/bash

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi

python3 -u ${BIN_DIR}/test.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
    --decode_cfg ${decode_config_path} \
    --result_file ${ckpt_prefix}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@@ -1,17 +0,0 @@
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
echo "BIN_DIR "${BIN_DIR}
@@ -1,29 +0,0 @@
#!/bin/bash
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=baidu_en8k_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p exp/${ckpt}/checkpoints
    bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
fi
@@ -1,5 +0,0 @@
exp
data
*log
tmp
nohup*
@@ -1 +0,0 @@
[]
@@ -1,64 +0,0 @@
# https://yaml.org/type/float.html
###########################################
# Data                                    #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
min_input_len: 0.0
max_input_len: 1000.0 # second
min_output_len: 0.0
max_output_len: .inf
min_output_input_ratio: 0.00
max_output_input_ratio: .inf

###########################################
# Dataloader                              #
###########################################
batch_size: 64 # one gpu
mean_std_filepath: data/mean_std.npz
unit_type: char
vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.json
random_seed: 0
spm_model_prefix:
spectrum_type: linear
feat_dim:
delta_delta: False
stride_ms: 10.0
window_ms: 20.0
n_fft: None
max_freq: None
target_sample_rate: 16000
use_dB_normalization: True
target_dB: -20
dither: 1.0
keep_transcription_text: False
sortagrad: True
shuffle_method: batch_shuffle
num_workers: 2

############################################
# Network Architecture                     #
############################################
num_conv_layers: 2
num_rnn_layers: 3
rnn_layer_size: 2048
use_gru: False
share_rnn_weights: True
blank_id: 28

###########################################
# Training                                #
###########################################
n_epoch: 80
accum_grad: 1
lr: 2e-3
lr_decay: 0.83
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
  kbest_n: 50
  latest_n: 5
@@ -1,10 +0,0 @@
decode_batch_size: 32
error_rate_type: wer
decoding_method: ctc_beam_search
lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5
beta: 0.3
beam_size: 500
cutoff_prob: 1.0
cutoff_top_n: 40
num_proc_bsearch: 8
@@ -1,83 +0,0 @@
#!/bin/bash

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

stage=-1
stop_stage=100
unit_type=char

source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}

bash local/download_model.sh ${ckpt_dir}
if [ $? -ne 0 ]; then
    exit 1
fi

cd ${ckpt_dir}
tar xzvf librispeech_v1.8_to_v2.x.tar.gz
cd -
mv ${ckpt_dir}/mean_std.npz data/
mv ${ckpt_dir}/vocab.txt data/

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # download data, generate manifests
    python3 ${TARGET_DIR}/librispeech/librispeech.py \
        --manifest_prefix="data/manifest" \
        --target_dir="${TARGET_DIR}/librispeech" \
        --full_download="True"

    if [ $? -ne 0 ]; then
        echo "Prepare LibriSpeech failed. Terminated."
        exit 1
    fi

    for set in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
        mv data/manifest.${set} data/manifest.${set}.raw
    done

    rm -rf data/manifest.train.raw data/manifest.dev.raw data/manifest.test.raw
    for set in train-clean-100 train-clean-360 train-other-500; do
        cat data/manifest.${set}.raw >> data/manifest.train.raw
    done

    for set in dev-clean dev-other; do
        cat data/manifest.${set}.raw >> data/manifest.dev.raw
    done

    for set in test-clean test-other; do
        cat data/manifest.${set}.raw >> data/manifest.test.raw
    done
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    for set in train dev test dev-clean dev-other test-clean test-other; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
            --cmvn_path "data/mean_std.npz" \
            --unit_type ${unit_type} \
            --vocab_path="data/vocab.txt" \
            --manifest_path="data/manifest.${set}.raw" \
            --output_path="data/manifest.${set}"

        if [ $? -ne 0 ]; then
            echo "Format manifest.${set} failed. Terminated."
            exit 1
        fi
    } &
    done
    wait
fi

echo "LibriSpeech data preparation done."
exit 0
@@ -1,22 +0,0 @@
#!/bin/bash

. ${MAIN_ROOT}/utils/utility.sh

DIR=data/lm
mkdir -p ${DIR}

URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm

echo "Start downloading the language model. The language model is large, please wait for a moment ..."
download $URL $MD5 $TARGET > /dev/null 2>&1
if [ $? -ne 0 ]; then
    echo "Failed to download the language model!"
    exit 1
else
    echo "Downloaded the language model successfully."
fi


exit 0
@@ -1,25 +0,0 @@
#! /usr/bin/env bash

if [ $# != 1 ];then
    echo "usage: ${0} ckpt_dir"
    exit -1
fi

ckpt_dir=$1

. ${MAIN_ROOT}/utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_v1.8_to_v2.x.tar.gz'
MD5=a06d9aadb560ea113984dc98d67232c8
TARGET=${ckpt_dir}/librispeech_v1.8_to_v2.x.tar.gz


echo "Download LibriSpeech model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download LibriSpeech model!"
    exit 1
fi


exit 0
@@ -1,36 +0,0 @@
#!/bin/bash

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix model_type"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
model_type=$4

# download language model
bash local/download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi

python3 -u ${BIN_DIR}/test.py \
    --ngpu ${ngpu} \
    --config ${config_path} \
    --decode_cfg ${decode_config_path} \
    --result_file ${ckpt_prefix}.rsl \
    --checkpoint_path ${ckpt_prefix} \
    --model_type ${model_type}

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@@ -1,16 +0,0 @@
export MAIN_ROOT=`realpath ${PWD}/../../../../`
export LOCAL_DEEPSPEECH2=`realpath ${PWD}/../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export PYTHONPATH=${LOCAL_DEEPSPEECH2}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

MODEL=deepspeech2
export BIN_DIR=${LOCAL_DEEPSPEECH2}/src_deepspeech2x/bin
@@ -1,28 +0,0 @@
#!/bin/bash
set -e
source path.sh

stage=0
stop_stage=100
conf_path=conf/deepspeech2.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=1
model_type=offline
gpus=1

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

v18_ckpt=librispeech_v1.8
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    mkdir -p exp/${ckpt}/checkpoints
    bash ./local/data.sh exp/${ckpt}/checkpoints || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=${gpus} ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${v18_ckpt} ${model_type} || exit -1
fi
@@ -1,370 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple
from typing import Union

import paddle
from paddle import nn
from paddle.fluid import core
from paddle.nn import functional as F

from paddlespeech.s2t.utils.log import Log

#TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog()

########### hack logging #############
logger.warn = logger.warning

########### hack paddle #############
paddle.half = 'float16'
paddle.float = 'float32'
paddle.double = 'float64'
paddle.short = 'int16'
paddle.int = 'int32'
paddle.long = 'int64'
paddle.uint16 = 'uint16'
paddle.cdouble = 'complex128'


def convert_dtype_to_string(tensor_dtype):
    """
    Convert a framework tensor dtype (core.VarDesc.VarType) to the
    corresponding paddle dtype.
    Args:
        tensor_dtype(core.VarDesc.VarType): the framework tensor data type.
    Returns:
        The corresponding paddle dtype.
    """
    dtype = tensor_dtype
    if dtype == core.VarDesc.VarType.FP32:
        return paddle.float32
    elif dtype == core.VarDesc.VarType.FP64:
        return paddle.float64
    elif dtype == core.VarDesc.VarType.FP16:
        return paddle.float16
    elif dtype == core.VarDesc.VarType.INT32:
        return paddle.int32
    elif dtype == core.VarDesc.VarType.INT16:
        return paddle.int16
    elif dtype == core.VarDesc.VarType.INT64:
        return paddle.int64
    elif dtype == core.VarDesc.VarType.BOOL:
        return paddle.bool
    elif dtype == core.VarDesc.VarType.BF16:
        # since there is still no support for bfloat16 in NumPy,
        # uint16 is used for casting bfloat16
        return paddle.uint16
    elif dtype == core.VarDesc.VarType.UINT8:
        return paddle.uint8
    elif dtype == core.VarDesc.VarType.INT8:
        return paddle.int8
    elif dtype == core.VarDesc.VarType.COMPLEX64:
        return paddle.complex64
    elif dtype == core.VarDesc.VarType.COMPLEX128:
        return paddle.complex128
    else:
        raise ValueError("Not supported tensor dtype %s" % dtype)


if not hasattr(paddle, 'softmax'):
    logger.warn("register user softmax to paddle, remove this when fixed!")
    setattr(paddle, 'softmax', paddle.nn.functional.softmax)

if not hasattr(paddle, 'log_softmax'):
    logger.warn("register user log_softmax to paddle, remove this when fixed!")
    setattr(paddle, 'log_softmax', paddle.nn.functional.log_softmax)

if not hasattr(paddle, 'sigmoid'):
    logger.warn("register user sigmoid to paddle, remove this when fixed!")
    setattr(paddle, 'sigmoid', paddle.nn.functional.sigmoid)

if not hasattr(paddle, 'log_sigmoid'):
    logger.warn("register user log_sigmoid to paddle, remove this when fixed!")
    setattr(paddle, 'log_sigmoid', paddle.nn.functional.log_sigmoid)

if not hasattr(paddle, 'relu'):
    logger.warn("register user relu to paddle, remove this when fixed!")
    setattr(paddle, 'relu', paddle.nn.functional.relu)


def cat(xs, dim=0):
    return paddle.concat(xs, axis=dim)


if not hasattr(paddle, 'cat'):
    logger.warn(
        "override cat of paddle if exists or register, remove this when fixed!")
    paddle.cat = cat


########### hack paddle.Tensor #############
def item(x: paddle.Tensor):
    return x.numpy().item()


if not hasattr(paddle.Tensor, 'item'):
    logger.warn(
        "override item of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.item = item


def func_long(x: paddle.Tensor):
    return paddle.cast(x, paddle.long)


if not hasattr(paddle.Tensor, 'long'):
    logger.warn(
        "override long of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.long = func_long

if not hasattr(paddle.Tensor, 'numel'):
    logger.warn(
        "override numel of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.numel = paddle.numel


def new_full(x: paddle.Tensor,
             size: Union[List[int], Tuple[int], paddle.Tensor],
             fill_value: Union[float, int, bool, paddle.Tensor],
             dtype=None):
    # fall back to the source tensor's dtype when no explicit dtype is given
    return paddle.full(size, fill_value,
                       dtype=dtype if dtype is not None else x.dtype)


if not hasattr(paddle.Tensor, 'new_full'):
    logger.warn(
        "override new_full of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.new_full = new_full


def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
    if convert_dtype_to_string(xs.dtype) == paddle.bool:
        xs = xs.astype(paddle.int)
    return xs.equal(
        paddle.to_tensor(
            ys, dtype=convert_dtype_to_string(xs.dtype), place=xs.place))


if not hasattr(paddle.Tensor, 'eq'):
    logger.warn(
        "override eq of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.eq = eq

if not hasattr(paddle, 'eq'):
    logger.warn(
        "override eq of paddle if exists or register, remove this when fixed!")
    paddle.eq = eq


def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
    return xs


if not hasattr(paddle.Tensor, 'contiguous'):
    logger.warn(
        "override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
    )
    paddle.Tensor.contiguous = contiguous


def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    nargs = len(args)
    assert (nargs <= 1)
    s = paddle.shape(xs)
    if nargs == 1:
        return s[args[0]]
    else:
        return s


#`to_static` do not process `size` property, maybe some `paddle` api dependent on it.
logger.warn(
    "override size of paddle.Tensor "
    "(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
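
# A minimal usage sketch of the size() hack (hypothetical values):
#   x = paddle.zeros([4, 8])
#   x.size()   # -> Tensor([4, 8]); paddle.shape returns a Tensor, not a tuple
#   x.size(1)  # -> Tensor(8)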


def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
    return xs.reshape(args)


if not hasattr(paddle.Tensor, 'view'):
    logger.warn("register user view to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.view = view


def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
    return xs.reshape(ys.size())


if not hasattr(paddle.Tensor, 'view_as'):
    logger.warn(
        "register user view_as to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.view_as = view_as


def is_broadcastable(shp1, shp2):
    # two shapes broadcast if, aligned from the trailing dimension,
    # each pair of dims is equal or one of them is 1
    for a, b in zip(shp1[::-1], shp2[::-1]):
        if a == 1 or b == 1 or a == b:
            pass
        else:
            return False
    return True


def masked_fill(xs: paddle.Tensor,
                mask: paddle.Tensor,
                value: Union[float, int]):
    assert is_broadcastable(xs.shape, mask.shape) is True
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    trues = paddle.ones_like(xs) * value
    xs = paddle.where(mask, trues, xs)
    return xs


if not hasattr(paddle.Tensor, 'masked_fill'):
    logger.warn(
        "register user masked_fill to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.masked_fill = masked_fill
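
# A minimal usage sketch of masked_fill (hypothetical values):
#   xs = paddle.to_tensor([1., 2., 3.])
#   mask = paddle.to_tensor([True, False, True])
#   xs.masked_fill(mask, 0.)  # -> [0., 2., 0.]; True positions take the fill value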


def masked_fill_(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]) -> paddle.Tensor:
    assert is_broadcastable(xs.shape, mask.shape) is True
    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
    mask = mask.broadcast_to(bshape)
    trues = paddle.ones_like(xs) * value
    ret = paddle.where(mask, trues, xs)
    paddle.assign(ret.detach(), output=xs)
    return xs


if not hasattr(paddle.Tensor, 'masked_fill_'):
    logger.warn(
        "register user masked_fill_ to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.masked_fill_ = masked_fill_


def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
    val = paddle.full_like(xs, value)
    paddle.assign(val.detach(), output=xs)
    return xs


if not hasattr(paddle.Tensor, 'fill_'):
    logger.warn("register user fill_ to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.fill_ = fill_


def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
    return paddle.tile(xs, size)


if not hasattr(paddle.Tensor, 'repeat'):
    logger.warn(
        "register user repeat to paddle.Tensor, remove this when fixed!")
    paddle.Tensor.repeat = repeat

if not hasattr(paddle.Tensor, 'softmax'):
    logger.warn(
        "register user softmax to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'softmax', paddle.nn.functional.softmax)

if not hasattr(paddle.Tensor, 'sigmoid'):
    logger.warn(
        "register user sigmoid to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'sigmoid', paddle.nn.functional.sigmoid)

if not hasattr(paddle.Tensor, 'relu'):
    logger.warn("register user relu to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'relu', paddle.nn.functional.relu)


def type_as(x: paddle.Tensor, other: paddle.Tensor) -> paddle.Tensor:
    return x.astype(other.dtype)


if not hasattr(paddle.Tensor, 'type_as'):
    logger.warn(
        "register user type_as to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'type_as', type_as)


def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
    assert len(args) == 1
    if isinstance(args[0], str):  # dtype
        return x.astype(args[0])
    elif isinstance(args[0], paddle.Tensor):  # Tensor
        return x.astype(args[0].dtype)
    else:  # Device
        return x


if not hasattr(paddle.Tensor, 'to'):
    logger.warn("register user to to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'to', to)


def func_float(x: paddle.Tensor) -> paddle.Tensor:
    return x.astype(paddle.float)


if not hasattr(paddle.Tensor, 'float'):
    logger.warn("register user float to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'float', func_float)


def func_int(x: paddle.Tensor) -> paddle.Tensor:
    return x.astype(paddle.int)


if not hasattr(paddle.Tensor, 'int'):
    logger.warn("register user int to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'int', func_int)


def tolist(x: paddle.Tensor) -> List[Any]:
    return x.numpy().tolist()


if not hasattr(paddle.Tensor, 'tolist'):
    logger.warn(
        "register user tolist to paddle.Tensor, remove this when fixed!")
    setattr(paddle.Tensor, 'tolist', tolist)


########### hack paddle.nn #############
class GLU(nn.Layer):
    """Gated Linear Units (GLU) Layer"""

    def __init__(self, dim: int=-1):
        super().__init__()
        self.dim = dim

    def forward(self, xs):
        return F.glu(xs, axis=self.dim)


if not hasattr(paddle.nn, 'GLU'):
    logger.warn("register user GLU to paddle.nn, remove this when fixed!")
    setattr(paddle.nn, 'GLU', GLU)
@@ -1,59 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from src_deepspeech2x.test_model import DeepSpeech2Tester as Tester
from yacs.config import CfgNode

from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments


def main_sp(config, args):
    exp = Tester(config, args)
    exp.setup()
    exp.run_test()


def main(config, args):
    main_sp(config, args)


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument(
        "--model_type", type=str, default='offline', help='offline/online')
    # save asr result to
    parser.add_argument(
        "--result_file", type=str, help="path to save the asr result")
    args = parser.parse_args()
    print_arguments(args, globals())
    print("model_type:{}".format(args.model_type))

    # https://yaml.org/type/float.html
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)
    if args.decode_cfg:
        decode_confs = CfgNode(new_allowed=True)
        decode_confs.merge_from_file(args.decode_cfg)
        config.decode = decode_confs
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    print(config)
    if args.dump_config:
        with open(args.dump_config, 'w') as f:
            print(config, file=f)

    main(config, args)
@@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -1,17 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .deepspeech2 import DeepSpeech2InferModel
from .deepspeech2 import DeepSpeech2Model

__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
@ -1,275 +0,0 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Deepspeech2 ASR Model"""
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from src_deepspeech2x.models.ds2.rnn import RNNStack
|
||||
|
||||
from paddlespeech.s2t.models.ds2.conv import ConvStack
|
||||
from paddlespeech.s2t.modules.ctc import CTCDecoder
|
||||
from paddlespeech.s2t.utils import layer_tools
|
||||
from paddlespeech.s2t.utils.checkpoint import Checkpoint
|
||||
from paddlespeech.s2t.utils.log import Log
|
||||
logger = Log(__name__).getlog()
|
||||
|
||||
__all__ = ['DeepSpeech2Model', 'DeepSpeech2InferModel']
|
||||
|
||||
|
||||
class CRNNEncoder(nn.Layer):
|
||||
def __init__(self,
|
||||
feat_size,
|
||||
dict_size,
|
||||
num_conv_layers=2,
|
||||
num_rnn_layers=3,
|
||||
rnn_size=1024,
|
||||
use_gru=False,
|
||||
share_rnn_weights=True):
|
||||
super().__init__()
|
||||
self.rnn_size = rnn_size
|
||||
self.feat_size = feat_size # 161 for linear
|
||||
self.dict_size = dict_size
|
||||
|
||||
self.conv = ConvStack(feat_size, num_conv_layers)
|
||||
|
||||
i_size = self.conv.output_height # H after conv stack
|
||||
self.rnn = RNNStack(
|
||||
i_size=i_size,
|
||||
h_size=rnn_size,
|
||||
num_stacks=num_rnn_layers,
|
||||
use_gru=use_gru,
|
||||
share_rnn_weights=share_rnn_weights)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self.rnn_size * 2
|
||||
|
||||
def forward(self, audio, audio_len):
|
||||
"""Compute Encoder outputs
|
||||
|
||||
Args:
|
||||
audio (Tensor): [B, Tmax, D]
|
||||
text (Tensor): [B, Umax]
|
||||
audio_len (Tensor): [B]
|
||||
text_len (Tensor): [B]
|
||||
Returns:
|
||||
x (Tensor): encoder outputs, [B, T, D]
|
||||
x_lens (Tensor): encoder length, [B]
|
||||
"""
|
||||
# [B, T, D] -> [B, D, T]
|
||||
audio = audio.transpose([0, 2, 1])
|
||||
# [B, D, T] -> [B, C=1, D, T]
|
||||
x = audio.unsqueeze(1)
|
||||
x_lens = audio_len
|
||||
|
||||
# convolution group
|
||||
x, x_lens = self.conv(x, x_lens)
|
||||
x_val = x.numpy()
|
||||
|
||||
# convert data from convolution feature map to sequence of vectors
|
||||
#B, C, D, T = paddle.shape(x) # not work under jit
|
||||
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
|
||||
#x = x.reshape([B, T, C * D]) #[B, T, C*D] # not work under jit
|
||||
x = x.reshape([0, 0, -1]) #[B, T, C*D]
|
||||
|
||||
# remove padding part
|
||||
x, x_lens = self.rnn(x, x_lens) #[B, T, D]
|
||||
return x, x_lens


class DeepSpeech2Model(nn.Layer):
    """The DeepSpeech2 network structure.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: Variable
    :param text_data: Transcription text data layer.
    :type text_data: Variable
    :param audio_len: Valid sequence length data layer.
    :type audio_len: Variable
    :param masks: Masks data layer to reset padding.
    :type masks: Variable
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """

    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0):
        super().__init__()
        self.encoder = CRNNEncoder(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)
        assert self.encoder.output_size == rnn_size * 2

        self.decoder = CTCDecoder(
            odim=dict_size,  # <blank> is in vocab
            enc_n_units=self.encoder.output_size,
            blank_id=blank_id,  # first token is <blank>
            dropout_rate=0.0,
            reduction=True,  # sum
            batch_average=True)  # sum / batch_size

    def forward(self, audio, audio_len, text, text_len):
        """Compute model loss.

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]
            text (Tensor): [B, U]
            text_len (Tensor): [B]

        Returns:
            loss (Tensor): [1]
        """
        eouts, eouts_len = self.encoder(audio, audio_len)
        loss = self.decoder(eouts, eouts_len, text, text_len)
        return loss

    @paddle.no_grad()
    def decode(self, audio, audio_len):
        # decoders only accept string encoded in utf-8

        # make sure the decoder has been initialized
        eouts, eouts_len = self.encoder(audio, audio_len)
        probs = self.decoder.softmax(eouts)
        batch_size = probs.shape[0]
        self.decoder.reset_decoder(batch_size=batch_size)
        self.decoder.next(probs, eouts_len)
        trans_best, trans_beam = self.decoder.decode()
        return trans_best

    @classmethod
    def from_pretrained(cls, dataloader, config, checkpoint_path):
        """Build a DeepSpeech2Model from a pretrained model.

        Parameters
        ----------
        dataloader: paddle.io.DataLoader

        config: yacs.config.CfgNode
            model configs

        checkpoint_path: Path or str
            the path of pretrained model checkpoint, without extension name

        Returns
        -------
        DeepSpeech2Model
            The model built from pretrained result.
        """
        model = cls(feat_size=dataloader.collate_fn.feature_size,
                    dict_size=len(dataloader.collate_fn.vocab_list),
                    num_conv_layers=config.num_conv_layers,
                    num_rnn_layers=config.num_rnn_layers,
                    rnn_size=config.rnn_layer_size,
                    use_gru=config.use_gru,
                    share_rnn_weights=config.share_rnn_weights)
        infos = Checkpoint().load_parameters(
            model, checkpoint_path=checkpoint_path)
        logger.info(f"checkpoint info: {infos}")
        layer_tools.summary(model)
        return model

    @classmethod
    def from_config(cls, config):
        """Build a DeepSpeech2Model from config.

        Parameters
        ----------
        config: yacs.config.CfgNode
            config

        Returns
        -------
        DeepSpeech2Model
            The model built from config.
        """
        model = cls(feat_size=config.feat_size,
                    dict_size=config.dict_size,
                    num_conv_layers=config.num_conv_layers,
                    num_rnn_layers=config.num_rnn_layers,
                    rnn_size=config.rnn_layer_size,
                    use_gru=config.use_gru,
                    share_rnn_weights=config.share_rnn_weights,
                    blank_id=config.blank_id)
        return model
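

# Editor's usage sketch (not part of the original file): building the model
# from a yacs CfgNode; the field values below are illustrative, not the
# released configs.
def _example_from_config():
    from yacs.config import CfgNode
    config = CfgNode()
    config.feat_size = 161  # e.g. spectrogram feature dim
    config.dict_size = 4300  # vocab size incl. <blank>
    config.num_conv_layers = 2
    config.num_rnn_layers = 3
    config.rnn_layer_size = 1024
    config.use_gru = False
    config.share_rnn_weights = True
    config.blank_id = 0
    return DeepSpeech2Model.from_config(config)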


class DeepSpeech2InferModel(DeepSpeech2Model):
    def __init__(self,
                 feat_size,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=1024,
                 use_gru=False,
                 share_rnn_weights=True,
                 blank_id=0):
        super().__init__(
            feat_size=feat_size,
            dict_size=dict_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights,
            blank_id=blank_id)

    def forward(self, audio, audio_len):
        """Export model function.

        Args:
            audio (Tensor): [B, T, D]
            audio_len (Tensor): [B]

        Returns:
            probs: probs after softmax
        """
        eouts, eouts_len = self.encoder(audio, audio_len)
        probs = self.decoder.softmax(eouts)
        return probs, eouts_len

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None, None, self.encoder.feat_size],
                    dtype='float32'),  # audio, [B, T, D]
                paddle.static.InputSpec(shape=[None],
                                        dtype='int64'),  # audio_len, [B]
            ])
        return static_model
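
# Editor's export sketch (not part of the original file): the static model
# returned by export() above can be serialized for inference with
# paddle.jit.save; the save path is hypothetical.
def _example_export(infer_model: DeepSpeech2InferModel):
    infer_model.eval()
    static_model = infer_model.export()
    paddle.jit.save(static_model, 'exp/deepspeech2/checkpoints/avg_1.jit')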
@ -1,334 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from paddlespeech.s2t.modules.activation import brelu
from paddlespeech.s2t.modules.mask import make_non_pad_mask
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['RNNStack']


class RNNCell(nn.RNNCellBase):
    r"""
    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
    computes the outputs and updates states.
    The formula used is as follows:

    .. math::
        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})

        y_{t} & = h_{t}

    where :math:`act` is for :attr:`activation`.
    """

    def __init__(self,
                 hidden_size: int,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        if activation not in ["tanh", "relu", "brelu"]:
            raise ValueError(
                "activation for SimpleRNNCell should be tanh, relu or brelu, "
                "but got {}".format(activation))
        self.activation = activation
        self._activation_fn = paddle.tanh \
            if activation == "tanh" \
            else F.relu
        if activation == 'brelu':
            self._activation_fn = brelu

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        pre_h = states
        i2h = inputs
        if self.bias_ih is not None:
            i2h += self.bias_ih
        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self._activation_fn(i2h + h2h)
        return h, h

    @property
    def state_shape(self):
        return (self.hidden_size, )
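

# Editor's usage sketch (not part of the original file): RNNCell plugged into
# paddle.nn.RNN, as BiRNNWithBN does below. Since the cell has no input
# projection (i2h = inputs), the input feature dim must equal hidden_size.
def _example_rnn_cell():
    cell = RNNCell(hidden_size=8, activation='brelu')
    rnn = nn.RNN(cell, is_reverse=False, time_major=False)
    x = paddle.randn([2, 5, 8])  # [B, T, D] with D == hidden_size
    y, final_h = rnn(x)  # y: [2, 5, 8], final_h: [2, 8]
    return y, final_h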


class GRUCell(nn.RNNCellBase):
    r"""
    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
    it computes the outputs and updates states.
    The formula for GRU used is as follows:

    .. math::
        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})

        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})

        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))

        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}

        y_{t} & = h_{t}

    where :math:`\sigma` is the sigmoid function, and :math:`*` is the
    elementwise multiplication operator.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 name=None):
        super().__init__()
        std = 1.0 / math.sqrt(hidden_size)
        self.weight_hh = self.create_parameter(
            (3 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Uniform(-std, std))
        self.bias_ih = None
        self.bias_hh = self.create_parameter(
            (3 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Uniform(-std, std))

        self.hidden_size = hidden_size
        self.input_size = input_size
        self._gate_activation = F.sigmoid
        # note: relu rather than the tanh in the formula above, following
        # this DeepSpeech2 recipe
        self._activation = F.relu

    def forward(self, inputs, states=None):
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)

        pre_hidden = states  # shape [batch_size, hidden_size]

        x_gates = inputs
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        bias_u, bias_r, bias_c = paddle.split(
            self.bias_hh, num_or_sections=3, axis=0)

        weight_hh = paddle.transpose(
            self.weight_hh,
            perm=[1, 0])  # weight_hh: shape [hidden_size, 3 * hidden_size]
        w_u_r_c = paddle.flatten(weight_hh)
        size_u_r = self.hidden_size * 2 * self.hidden_size
        w_u_r = paddle.reshape(w_u_r_c[:size_u_r],
                               (self.hidden_size, self.hidden_size * 2))
        w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)
        w_c = paddle.reshape(w_u_r_c[size_u_r:],
                             (self.hidden_size, self.hidden_size))

        h_u = paddle.matmul(
            pre_hidden, w_u,
            transpose_y=False) + bias_u  # shape [batch_size, hidden_size]
        h_r = paddle.matmul(
            pre_hidden, w_r,
            transpose_y=False) + bias_r  # shape [batch_size, hidden_size]

        x_u, x_r, x_c = paddle.split(
            x_gates, num_or_sections=3, axis=1)  # shape [batch_size, hidden_size]

        u = self._gate_activation(x_u + h_u)  # shape [batch_size, hidden_size]
        r = self._gate_activation(x_r + h_r)  # shape [batch_size, hidden_size]
        c = self._activation(
            x_c + paddle.matmul(r * pre_hidden, w_c, transpose_y=False) +
            bias_c)  # [batch_size, hidden_size]

        h = (1 - u) * pre_hidden + u * c
        # https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/fluid/layers/dynamic_gru_cn.html#dynamic-gru
        return h, h

    @property
    def state_shape(self):
        r"""
        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
        size would be automatically inserted into shape). The shape corresponds
        to the shape of :math:`h_{t-1}`.
        """
        return (self.hidden_size, )
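

# Editor's layout sketch (not part of the original file): how the fused
# [3h, h] hidden weight above is sliced into update/reset/candidate blocks,
# mirroring GRUCell.forward. Purely illustrative.
def _example_gru_weight_split(h: int = 4):
    w_t = paddle.transpose(paddle.randn([3 * h, h]), perm=[1, 0])  # [h, 3h]
    flat = paddle.flatten(w_t)
    w_u_r = paddle.reshape(flat[:h * 2 * h], (h, 2 * h))
    w_u, w_r = paddle.split(w_u_r, num_or_sections=2, axis=1)  # each [h, h]
    w_c = paddle.reshape(flat[h * 2 * h:], (h, h))
    return w_u, w_r, w_c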


class BiRNNWithBN(nn.Layer):
    """Bidirectional simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param size: Dimension of RNN cells.
    :type size: int
    :param share_weights: Whether to share input-hidden weights between
                          forward and backward directional RNNs.
    :type share_weights: bool
    :return: Bidirectional simple rnn layer.
    :rtype: Variable
    """

    def __init__(self, i_size: int, h_size: int, share_weights: bool):
        super().__init__()
        self.share_weights = share_weights
        if self.share_weights:
            # input-hidden weights shared between bi-directional rnn.
            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            # batch norm is only performed on input-state projection
            self.fw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')
            self.bw_fc = self.fw_fc
            self.bw_bn = self.fw_bn
        else:
            self.fw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            self.fw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')
            self.bw_fc = nn.Linear(i_size, h_size, bias_attr=False)
            self.bw_bn = nn.BatchNorm1D(
                h_size, bias_attr=None, data_format='NLC')

        self.fw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.bw_cell = RNNCell(hidden_size=h_size, activation='brelu')
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  # [B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  # [B, T, D]

    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len


class BiGRUWithBN(nn.Layer):
    """Bidirectional gru layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: Variable
    :param size: Dimension of GRU cells.
    :type size: int
    :param act: Activation type.
    :type act: string
    :return: Bidirectional GRU layer.
    :rtype: Variable
    """

    def __init__(self, i_size: int, h_size: int):
        super().__init__()
        hidden_size = h_size * 3

        self.fw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
        self.fw_bn = nn.BatchNorm1D(
            hidden_size, bias_attr=None, data_format='NLC')
        self.bw_fc = nn.Linear(i_size, hidden_size, bias_attr=False)
        self.bw_bn = nn.BatchNorm1D(
            hidden_size, bias_attr=None, data_format='NLC')

        self.fw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
        self.bw_cell = GRUCell(input_size=hidden_size, hidden_size=h_size)
        self.fw_rnn = nn.RNN(
            self.fw_cell, is_reverse=False, time_major=False)  # [B, T, D]
        self.bw_rnn = nn.RNN(
            self.bw_cell, is_reverse=True, time_major=False)  # [B, T, D]

    def forward(self, x, x_len):
        # x, shape [B, T, D]
        fw_x = self.fw_bn(self.fw_fc(x))
        bw_x = self.bw_bn(self.bw_fc(x))
        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
        x = paddle.concat([fw_x, bw_x], axis=-1)
        return x, x_len


class RNNStack(nn.Layer):
    """RNN group with stacked bidirectional simple RNN or GRU layers.

    :param input: Input layer.
    :type input: Variable
    :param size: Dimension of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: Output layer of the RNN group.
    :rtype: Variable
    """

    def __init__(self,
                 i_size: int,
                 h_size: int,
                 num_stacks: int,
                 use_gru: bool,
                 share_rnn_weights: bool):
        super().__init__()
        rnn_stacks = []
        for i in range(num_stacks):
            if use_gru:
                # GRU layer (the candidate activation in GRUCell is relu)
                rnn_stacks.append(BiGRUWithBN(i_size=i_size, h_size=h_size))
            else:
                rnn_stacks.append(
                    BiRNNWithBN(
                        i_size=i_size,
                        h_size=h_size,
                        share_weights=share_rnn_weights))
            i_size = h_size * 2

        self.rnn_stacks = nn.LayerList(rnn_stacks)

    def forward(self, x: paddle.Tensor, x_len: paddle.Tensor):
        """
        x: shape [B, T, D]
        x_len: shape [B]
        """
        for i, rnn in enumerate(self.rnn_stacks):
            x, x_len = rnn(x, x_len)
            masks = make_non_pad_mask(x_len)  # [B, T]
            masks = masks.unsqueeze(-1)  # [B, T, 1]
            # TODO(Hui Zhang): bool multiply is not supported yet
            masks = masks.astype(x.dtype)
            x = x.multiply(masks)
        return x, x_len
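
# Editor's usage sketch (not part of the original file): a two-layer
# bidirectional stack; the output dim is 2 * h_size from the fw/bw concat and
# padded frames are zeroed by the mask. Sizes are illustrative.
def _example_rnn_stack():
    stack = RNNStack(i_size=1312, h_size=1024, num_stacks=2,
                     use_gru=False, share_rnn_weights=True)
    x = paddle.randn([2, 50, 1312])  # [B, T, D]
    x_len = paddle.to_tensor([50, 30], dtype='int64')
    y, y_len = stack(x, x_len)  # y: [2, 50, 2048]
    return y, y_len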
@ -1,357 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DeepSpeech2 and DeepSpeech2Online model."""
import time
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from src_deepspeech2x.models.ds2 import DeepSpeech2InferModel
from src_deepspeech2x.models.ds2 import DeepSpeech2Model

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.io.sampler import SortagradBatchSampler
from paddlespeech.s2t.io.sampler import SortagradDistributedBatchSampler
from paddlespeech.s2t.models.ds2_online import DeepSpeech2InferModelOnline
from paddlespeech.s2t.models.ds2_online import DeepSpeech2ModelOnline
from paddlespeech.s2t.training.gradclip import ClipGradByGlobalNormWithLog
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()


class DeepSpeech2Trainer(Trainer):
    def __init__(self, config, args):
        super().__init__(config, args)

    def train_batch(self, batch_index, batch_data, msg):
        train_conf = self.config
        start = time.time()

        # forward
        utt, audio, audio_len, text, text_len = batch_data
        loss = self.model(audio, audio_len, text, text_len)
        losses_np = {
            'train_loss': float(loss),
        }

        # loss backward
        if (batch_index + 1) % train_conf.accum_grad != 0:
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            context = self.model.no_sync
        else:
            # Used for single gpu training and DDP gradient synchronization
            # processes.
            context = nullcontext

        with context():
            loss.backward()
            layer_tools.print_grads(self.model, print_func=None)

        # optimizer step
        if (batch_index + 1) % train_conf.accum_grad == 0:
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.iteration += 1

        iteration_time = time.time() - start

        msg += "train time: {:>.3f}s, ".format(iteration_time)
        msg += "batch size: {}, ".format(self.config.batch_size)
        msg += "accum: {}, ".format(train_conf.accum_grad)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in losses_np.items())
        logger.info(msg)

        if dist.get_rank() == 0 and self.visualizer:
            for k, v in losses_np.items():
                # `step - 1` since we update `step` after optimizer.step().
                self.visualizer.add_scalar("train/{}".format(k), v,
                                           self.iteration - 1)
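
    # Editor's note (sketch, not part of the original file): the accumulation
    # pattern above amounts to
    #
    #   sync = (batch_index + 1) % accum_grad == 0
    #   with (nullcontext() if sync else model.no_sync()):
    #       loss.backward()              # grads accumulate locally
    #   if sync:
    #       optimizer.step()             # one update per accum_grad batches
    #       optimizer.clear_grad()
    #
    # so DDP all-reduces gradients only on the synchronizing batch.
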
    @paddle.no_grad()
    def valid(self):
        logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
        self.model.eval()
        valid_losses = defaultdict(list)
        num_seen_utts = 1
        total_loss = 0.0
        for i, batch in enumerate(self.valid_loader):
            utt, audio, audio_len, text, text_len = batch
            loss = self.model(audio, audio_len, text, text_len)
            if paddle.isfinite(loss):
                num_utts = batch[1].shape[0]
                num_seen_utts += num_utts
                total_loss += float(loss) * num_utts
                valid_losses['val_loss'].append(float(loss))

            if (i + 1) % self.config.log_interval == 0:
                valid_dump = {k: np.mean(v) for k, v in valid_losses.items()}
                valid_dump['val_history_loss'] = total_loss / num_seen_utts

                # logging
                msg = f"Valid: Rank: {dist.get_rank()}, "
                msg += "epoch: {}, ".format(self.epoch)
                msg += "step: {}, ".format(self.iteration)
                msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
                msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                 for k, v in valid_dump.items())
                logger.info(msg)

        logger.info('Rank {} Val info val_loss {}'.format(
            dist.get_rank(), total_loss / num_seen_utts))
        return total_loss, num_seen_utts

    def setup_model(self):
        config = self.config.clone()
        config.defrost()
        config.feat_size = self.train_loader.collate_fn.feature_size
        #config.dict_size = self.train_loader.collate_fn.vocab_size
        config.dict_size = len(self.train_loader.collate_fn.vocab_list)
        config.freeze()

        if self.args.model_type == 'offline':
            model = DeepSpeech2Model.from_config(config)
        elif self.args.model_type == 'online':
            model = DeepSpeech2ModelOnline.from_config(config)
        else:
            raise Exception("wrong model type")
        if self.parallel:
            model = paddle.DataParallel(model)

        logger.info(f"{model}")
        layer_tools.print_params(model, logger.info)

        grad_clip = ClipGradByGlobalNormWithLog(config.global_grad_clip)
        lr_scheduler = paddle.optimizer.lr.ExponentialDecay(
            learning_rate=config.lr, gamma=config.lr_decay, verbose=True)
        optimizer = paddle.optimizer.Adam(
            learning_rate=lr_scheduler,
            parameters=model.parameters(),
            weight_decay=paddle.regularizer.L2Decay(config.weight_decay),
            grad_clip=grad_clip)

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        logger.info("Setup model/optimizer/lr_scheduler!")

    def setup_dataloader(self):
        config = self.config.clone()
        config.defrost()
        config.keep_transcription_text = False

        config.manifest = config.train_manifest
        train_dataset = ManifestDataset.from_config(config)

        config.manifest = config.dev_manifest
        dev_dataset = ManifestDataset.from_config(config)

        config.manifest = config.test_manifest
        test_dataset = ManifestDataset.from_config(config)

        if self.parallel:
            batch_sampler = SortagradDistributedBatchSampler(
                train_dataset,
                batch_size=config.batch_size,
                num_replicas=None,
                rank=None,
                shuffle=True,
                drop_last=True,
                sortagrad=config.sortagrad,
                shuffle_method=config.shuffle_method)
        else:
            batch_sampler = SortagradBatchSampler(
                train_dataset,
                shuffle=True,
                batch_size=config.batch_size,
                drop_last=True,
                sortagrad=config.sortagrad,
                shuffle_method=config.shuffle_method)

        collate_fn_train = SpeechCollator.from_config(config)

        config.augmentation_config = ""
        collate_fn_dev = SpeechCollator.from_config(config)

        config.keep_transcription_text = True
        config.augmentation_config = ""
        collate_fn_test = SpeechCollator.from_config(config)

        self.train_loader = DataLoader(
            train_dataset,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn_train,
            num_workers=config.num_workers)
        self.valid_loader = DataLoader(
            dev_dataset,
            batch_size=config.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_dev)
        self.test_loader = DataLoader(
            test_dataset,
            batch_size=config.decode.decode_batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn_test)
        if "<eos>" in self.test_loader.collate_fn.vocab_list:
            self.test_loader.collate_fn.vocab_list.remove("<eos>")
        if "<eos>" in self.valid_loader.collate_fn.vocab_list:
            self.valid_loader.collate_fn.vocab_list.remove("<eos>")
        if "<eos>" in self.train_loader.collate_fn.vocab_list:
            self.train_loader.collate_fn.vocab_list.remove("<eos>")
        logger.info("Setup train/valid/test Dataloader!")


class DeepSpeech2Tester(DeepSpeech2Trainer):
    def __init__(self, config, args):
        self._text_featurizer = TextFeaturizer(
            unit_type=config.unit_type, vocab=None)
        super().__init__(config, args)

    def ordid2token(self, texts, texts_len):
        """Convert ord() ids back to chr() characters."""
        trans = []
        for text, n in zip(texts, texts_len):
            n = n.numpy().item()
            ids = text[:n]
            trans.append(''.join([chr(i) for i in ids]))
        return trans

    def compute_metrics(self,
                        utts,
                        audio,
                        audio_len,
                        texts,
                        texts_len,
                        fout=None):
        cfg = self.config.decode
        errors_sum, len_refs, num_ins = 0.0, 0, 0
        errors_func = error_rate.char_errors if cfg.error_rate_type == 'cer' else error_rate.word_errors
        error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer

        target_transcripts = self.ordid2token(texts, texts_len)

        result_transcripts = self.compute_result_transcripts(audio, audio_len)

        for utt, target, result in zip(utts, target_transcripts,
                                       result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
            if fout:
                fout.write(utt + " " + result + "\n")
            logger.info("\nTarget Transcription: %s\nOutput Transcription: %s" %
                        (target, result))
            logger.info("Current error rate [%s] = %f" %
                        (cfg.error_rate_type, error_rate_func(target, result)))

        return dict(
            errors_sum=errors_sum,
            len_refs=len_refs,
            num_ins=num_ins,
            error_rate=errors_sum / len_refs,
            error_rate_type=cfg.error_rate_type)
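
    # Editor's note (sketch, not part of the original file): corpus-level
    # error rate is the sum of edit errors over the sum of reference lengths,
    # not the mean of per-utterance rates, e.g. for cer:
    #
    #   errors, len_ref = error_rate.char_errors(target, result)
    #   errors_sum += errors; len_refs += len_ref
    #   corpus_cer = errors_sum / len_refs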

    def compute_result_transcripts(self, audio, audio_len):
        result_transcripts = self.model.decode(audio, audio_len)

        result_transcripts = [
            self._text_featurizer.detokenize(item)
            for item in result_transcripts
        ]
        return result_transcripts

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def test(self):
        logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
        self.model.eval()
        cfg = self.config
        error_rate_type = None
        errors_sum, len_refs, num_ins = 0.0, 0, 0

        # initialize the decoder in the model
        decode_cfg = self.config.decode
        vocab_list = self.test_loader.collate_fn.vocab_list
        decode_batch_size = self.test_loader.batch_size
        self.model.decoder.init_decoder(
            decode_batch_size, vocab_list, decode_cfg.decoding_method,
            decode_cfg.lang_model_path, decode_cfg.alpha, decode_cfg.beta,
            decode_cfg.beam_size, decode_cfg.cutoff_prob,
            decode_cfg.cutoff_top_n, decode_cfg.num_proc_bsearch)

        with open(self.args.result_file, 'w') as fout:
            for i, batch in enumerate(self.test_loader):
                utts, audio, audio_len, texts, texts_len = batch
                metrics = self.compute_metrics(utts, audio, audio_len, texts,
                                               texts_len, fout)
                errors_sum += metrics['errors_sum']
                len_refs += metrics['len_refs']
                num_ins += metrics['num_ins']
                error_rate_type = metrics['error_rate_type']
                logger.info("Error rate [%s] (%d/?) = %f" %
                            (error_rate_type, num_ins, errors_sum / len_refs))

        # logging
        msg = "Test: "
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
        msg += "Final error rate [%s] (%d/%d) = %f" % (
            error_rate_type, num_ins, num_ins, errors_sum / len_refs)
        logger.info(msg)
        self.model.decoder.del_decoder()

    def run_test(self):
        self.resume_or_scratch()
        try:
            self.test()
        except KeyboardInterrupt:
            exit(-1)

    def export(self):
        if self.args.model_type == 'offline':
            infer_model = DeepSpeech2InferModel.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        elif self.args.model_type == 'online':
            infer_model = DeepSpeech2InferModelOnline.from_pretrained(
                self.test_loader, self.config, self.args.checkpoint_path)
        else:
            raise Exception("wrong model type")

        infer_model.eval()
        feat_dim = self.test_loader.collate_fn.feature_size
        static_model = infer_model.export()
        logger.info(f"Export code: {static_model.forward.code}")
        paddle.jit.save(static_model, self.args.export_path)

    def run_export(self):
        try:
            self.export()
        except KeyboardInterrupt:
            exit(-1)
@ -1,151 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
    # e.g. "conformer_wenetspeech-zh-16k" and "panns_cnn6-32k".
    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
    "conformer_wenetspeech-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1_conformer_wenetspeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '76cb19ed857e6623856b7cd7ebbfeda4',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/conformer/checkpoints/wenetspeech',
    },
    "conformer_online_wenetspeech-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr1/asr1_chunk_conformer_wenetspeech_ckpt_1.0.0a.model.tar.gz',
        'md5':
        'b8c02632b04da34aca88459835be54a6',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/avg_10',
    },
    "conformer_online_multicn-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/multi_cn/asr1/asr1_chunk_conformer_multi_cn_ckpt_0.2.0.model.tar.gz',
        'md5':
        '7989b3248c898070904cf042fd656003',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/multi_cn',
    },
    "conformer_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_conformer_aishell_ckpt_0.1.2.model.tar.gz',
        'md5':
        '3f073eccfa7bb14e0c6867d65fc0dc3a',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/conformer/checkpoints/avg_30',
    },
    "conformer_online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr1/asr1_chunk_conformer_aishell_ckpt_0.2.0.model.tar.gz',
        'md5':
        'b374cfb93537761270b6224fb0bfc26a',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/chunk_conformer/checkpoints/avg_30',
    },
    "transformer_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        '2c667da24922aad391eacafe37bc1660',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/transformer/checkpoints/avg_10',
    },
    "deepspeech2online_wenetspeech-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/wenetspeech/asr0/asr0_deepspeech2_online_wenetspeech_ckpt_1.0.0a.model.tar.gz',
        'md5':
        'e393d4d274af0f6967db24fc146e8074',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_10',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2offline_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_aishell_ckpt_0.1.1.model.tar.gz',
        'md5':
        '932c3593d62fe5c741b59b31318aa314',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2online_aishell-zh-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/aishell/asr0/asr0_deepspeech2_online_aishell_fbank161_ckpt_0.2.1.model.tar.gz',
        'md5':
        '98b87b171b7240b7cae6e07d8d0bc9be',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2_online/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm',
        'lm_md5':
        '29e02312deb2e59b3c8686c7966d4fe3'
    },
    "deepspeech2offline_librispeech-en-16k": {
        'url':
        'https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr0/asr0_deepspeech2_librispeech_ckpt_0.1.1.model.tar.gz',
        'md5':
        'f5666c81ad015c8de03aac2bc92e5762',
        'cfg_path':
        'model.yaml',
        'ckpt_path':
        'exp/deepspeech2/checkpoints/avg_1',
        'lm_url':
        'https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm',
        'lm_md5':
        '099a601759d467cd0a8523ff939819c5'
    },
}

model_alias = {
    "deepspeech2offline":
    "paddlespeech.s2t.models.ds2:DeepSpeech2Model",
    "deepspeech2online":
    "paddlespeech.s2t.models.ds2_online:DeepSpeech2ModelOnline",
    "conformer":
    "paddlespeech.s2t.models.u2:U2Model",
    "conformer_online":
    "paddlespeech.s2t.models.u2:U2Model",
    "transformer":
    "paddlespeech.s2t.models.u2:U2Model",
    "wenetspeech":
    "paddlespeech.s2t.models.u2:U2Model",
}
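

# Editor's resolution sketch (not part of the original file): how a
# "module:Class" alias entry can be resolved with the standard library; the
# project's own import helper may differ.
def _resolve_model_class(model_type: str):
    import importlib
    module_name, cls_name = model_alias[model_type].split(":")
    return getattr(importlib.import_module(module_name), cls_name)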
@ -1,47 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    # The tags for pretrained_models should be "{model_name}[_{dataset}][-{lang}][-...]".
    # e.g. "conformer_wenetspeech-zh-16k", "transformer_aishell-zh-16k" and "panns_cnn6-32k".
    # Command line and python api use "{model_name}[_{dataset}]" as --model, usage:
    # "paddlespeech asr --model conformer_wenetspeech --lang zh --sr 16000 --input ./input.wav"
    "panns_cnn6-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn6.tar.gz',
        'md5': '4cf09194a95df024fd12f84712cf0f9c',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn6.pdparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn10-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn10.tar.gz',
        'md5': 'cb8427b22176cc2116367d14847f5413',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn10.pdparams',
        'label_file': 'audioset_labels.txt',
    },
    "panns_cnn14-32k": {
        'url': 'https://paddlespeech.bj.bcebos.com/cls/panns_cnn14.tar.gz',
        'md5': 'e3b9b5614a1595001161d0ab95edee97',
        'cfg_path': 'panns.yaml',
        'ckpt_path': 'cnn14.pdparams',
        'label_file': 'audioset_labels.txt',
    },
}

model_alias = {
    "panns_cnn6": "paddlespeech.cls.models.panns:CNN6",
    "panns_cnn10": "paddlespeech.cls.models.panns:CNN10",
    "panns_cnn14": "paddlespeech.cls.models.panns:CNN14",
}
@ -1,35 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pretrained_models = {
    "fat_st_ted-en-zh": {
        "url":
        "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/st1_transformer_mtl_noam_ted-en-zh_ckpt_0.1.1.model.tar.gz",
        "md5":
        "d62063f35a16d91210a71081bd2dd557",
        "cfg_path":
        "model.yaml",
        "ckpt_path":
        "exp/transformer_mtl_noam/checkpoints/fat_st_ted-en-zh.pdparams",
    }
}

model_alias = {"fat_st": "paddlespeech.s2t.models.u2_st:U2STModel"}

kaldi_bins = {
    "url":
    "https://paddlespeech.bj.bcebos.com/s2t/ted_en_zh/st1/kaldi_bins.tar.gz",
    "md5":
    "c0682303b3f3393dbf6ed4c4e35a53eb",
}
@ -1,146 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from typing import List

from prettytable import PrettyTable

from ..utils import cli_register
from ..utils import stats_wrapper

__all__ = ['StatsExecutor']

model_name_format = {
    'asr': 'Model-Language-Sample Rate',
    'cls': 'Model-Sample Rate',
    'st': 'Model-Source language-Target language',
    'text': 'Model-Task-Language',
    'tts': 'Model-Language',
    'vector': 'Model-Sample Rate'
}


@cli_register(
    name='paddlespeech.stats',
    description='Get the list of supported models for each speech task.')
class StatsExecutor():
    def __init__(self):
        super().__init__()

        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.stats', add_help=True)
        self.task_choices = ['asr', 'cls', 'st', 'text', 'tts', 'vector']
        self.parser.add_argument(
            '--task',
            type=str,
            default='asr',
            choices=self.task_choices,
            help='Choose speech task.',
            required=True)

    def show_support_models(self, pretrained_models: dict):
        fields = model_name_format[self.task].split("-")
        table = PrettyTable(fields)
        for key in pretrained_models:
            table.add_row(key.split("-"))
        print(table)
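
    # Editor's note (sketch, not part of the original file): for task 'asr'
    # the fields are ['Model', 'Language', 'Sample Rate'], so each tag splits
    # into one table row:
    #
    #   table = PrettyTable(['Model', 'Language', 'Sample Rate'])
    #   table.add_row("conformer_wenetspeech-zh-16k".split("-"))
    #   # -> row ['conformer_wenetspeech', 'zh', '16k']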

    def execute(self, argv: List[str]) -> bool:
        """
        Command line entry.
        """
        parser_args = self.parser.parse_args(argv)
        has_exceptions = False
        try:
            self(parser_args.task)
        except Exception:
            has_exceptions = True
        return not has_exceptions

    @stats_wrapper
    def __call__(
            self,
            task: str=None, ):
        """
        Python API to call an executor.
        """
        self.task = task
        if self.task not in self.task_choices:
            print("Please input a correct speech task; choices = " + str(
                self.task_choices))

        elif self.task == 'asr':
            try:
                from ..asr.pretrained_models import pretrained_models
                print(
                    "Here is the list of ASR pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print("Failed to get the list of ASR pretrained models.")

        elif self.task == 'cls':
            try:
                from ..cls.pretrained_models import pretrained_models
                print(
                    "Here is the list of CLS pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print("Failed to get the list of CLS pretrained models.")

        elif self.task == 'st':
            try:
                from ..st.pretrained_models import pretrained_models
                print(
                    "Here is the list of ST pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print("Failed to get the list of ST pretrained models.")

        elif self.task == 'text':
            try:
                from ..text.pretrained_models import pretrained_models
                print(
                    "Here is the list of TEXT pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print("Failed to get the list of TEXT pretrained models.")

        elif self.task == 'tts':
            try:
                from ..tts.pretrained_models import pretrained_models
                print(
                    "Here is the list of TTS pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print("Failed to get the list of TTS pretrained models.")

        elif self.task == 'vector':
            try:
                from ..vector.pretrained_models import pretrained_models
                print(
                    "Here is the list of Speaker Recognition pretrained models released by PaddleSpeech that can be used by the command line and Python API"
                )
                self.show_support_models(pretrained_models)
            except BaseException:
                print(
                    "Failed to get the list of Speaker Recognition pretrained models."
                )