[ASR] add code-switch asr tal_cs recipe (#2796)

* add tal_cs asr recipe.

* add readme and result, and fix some bug.

* add commit id and date.
pull/2802/head
zxcd 2 years ago committed by GitHub
parent 25dcad3de7
commit e793d267d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,13 @@
# [TAL_CSASR](https://ai.100tal.com/dataset/)
This data set is TAL English class audio, including mixed Chinese and English speech. Each audio has only one speaker, and this data set has more than 100 speakers. (File 63.36G) This data contains the sample of intra sentence and inter sentence mixing. The ratio between Chinese characters and English words in the data is 13:1.
- Total data: 587H (train_set: 555.9H, dev_set: 8H, test_set: 23.6H)
- Sample rate: 16000
- Sample bit: 16
- Recording device: microphone
- Speaker number: 200+
- Recording time: 2019
- Data format: audio: .wav; test: .txt
- Audio duration: 1-60s
- Data type: audio of English teachers' teaching

@ -0,0 +1,116 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare TALCS ASR datasets.
create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import io
import json
import os
import soundfile
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
TRAIN_SET = os.path.join(args.target_dir, "train_set")
DEV_SET = os.path.join(args.target_dir, "dev_set")
TEST_SET = os.path.join(args.target_dir, "test_set")
manifest_train_path = os.path.join(args.manifest_prefix, "manifest.train.raw")
manifest_dev_path = os.path.join(args.manifest_prefix, "manifest.dev.raw")
manifest_test_path = os.path.join(args.manifest_prefix, "manifest.test.raw")
def create_manifest(data_dir, manifest_path):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_char = 0.0
total_num = 0
wav_dir = os.path.join(data_dir, 'wav')
text_filepath = os.path.join(data_dir, 'label.txt')
for subfolder, _, filelist in sorted(os.walk(wav_dir)):
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
nchars = len(segments[1:])
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.wav'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
utt = os.path.splitext(os.path.basename(audio_filepath))[0]
utt2spk = '-'.join(utt.split('-')[:2])
json_lines.append(
json.dumps({
'utt': utt,
'utt2spk': utt2spk,
'feat': audio_filepath,
'feat_shape': (duration, ), # second
'text': text,
}))
total_sec += duration
total_char += nchars
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
manifest_dir = os.path.dirname(manifest_path)
data_dir_name = os.path.split(data_dir)[-1]
meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
with open(meta_path, 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_char} char", file=f)
print(f"{total_char / total_sec} char/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
create_manifest(TRAIN_SET, manifest_train_path)
create_manifest(DEV_SET, manifest_dev_path)
create_manifest(TEST_SET, manifest_test_path)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()

@ -17,6 +17,7 @@ Acoustic Model | Training Data | Token-based | Size | Descriptions | CER | WER |
[Conformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_conformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 191 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0338 | 960 h | [Conformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
[Transformer Librispeech ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr1/asr1_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0381 | 960 h | [Transformer Librispeech ASR1](../../examples/librispeech/asr1) | python |
[Transformer Librispeech ASR2 Model](https://paddlespeech.bj.bcebos.com/s2t/librispeech/asr2/asr2_transformer_librispeech_ckpt_0.1.1.model.tar.gz) | Librispeech Dataset | subword-based | 131 MB | Encoder:Transformer, Decoder:Transformer, Decoding method: JoinCTC w/ LM |-| 0.0240 | 960 h | [Transformer Librispeech ASR2](../../examples/librispeech/asr2) | python |
[Conformer TALCS ASR1 Model](https://paddlespeech.bj.bcebos.com/s2t/tal_cs/asr1/asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz) | TALCS Dataset | subword-based | 470 MB | Encoder:Conformer, Decoder:Transformer, Decoding method: Attention rescoring |-| 0.0844 | 587 h | [Conformer TALCS ASR1](../../examples/tal_cs/asr1) | python |
### Self-Supervised Pre-trained Model
Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions | CER | WER | Example Link |
@ -29,7 +30,7 @@ Model | Pre-Train Method | Pre-Train Data | Finetune Data | Size | Descriptions
### Whisper Model
Demo Link | Training Data | Size | Descriptions | CER | Model
:-----------: | :-----:| :-------: | :-----: | :-----: |:---------:|
[Whisper](../../demos/whisper) | 680kh from internet | large: 5.8G,</br>medium: 2.9G,</br>small: 923M,</br>base: 277M,</br>tiny: 145M | Encoder:Transformer,</br> Decoder:Transformer, </br>Decoding method: </br>Greedy search | 2.7 </br>(large, Librispeech) | [whisper-large](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz) </br>[whisper-medium](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz) </br>[whisper-medium-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz) </br>[whisper-small](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz) </br>[whisper-small-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz) </br>[whisper-base](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz) </br>[whisper-base-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz) </br>[whisper-tiny](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz) </br>[whisper-tiny-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz)
[Whisper](../../demos/whisper) | 680kh from internet | large: 5.8G,</br>medium: 2.9G,</br>small: 923M,</br>base: 277M,</br>tiny: 145M | Encoder:Transformer,</br> Decoder:Transformer, </br>Decoding method: </br>Greedy search | 0.027 </br>(large, Librispeech) | [whisper-large](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-large-model.tar.gz) </br>[whisper-medium](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-model.tar.gz) </br>[whisper-medium-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-medium-en-model.tar.gz) </br>[whisper-small](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-model.tar.gz) </br>[whisper-small-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-small-en-model.tar.gz) </br>[whisper-base](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-model.tar.gz) </br>[whisper-base-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-base-en-model.tar.gz) </br>[whisper-tiny](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-model.tar.gz) </br>[whisper-tiny-English-only](https://paddlespeech.bj.bcebos.com/whisper/whisper_model_20221122/whisper-tiny-en-model.tar.gz)
### Language Model based on NGram
|Language Model | Training Data | Token-based | Size | Descriptions|

@ -0,0 +1,190 @@
# Transformer/Conformer ASR with TALCS
This example contains code used to train [u2](https://arxiv.org/pdf/2012.05481.pdf) model (Transformer or [Conformer](https://arxiv.org/pdf/2005.08100.pdf) model) with [TALCS dataset](https://ai.100tal.com/dataset)
## Overview
All the scripts you need are in `run.sh`. There are several stages in `run.sh`, and each stage has its function.
| Stage | Function |
|:---- |:----------------------------------------------------------- |
| 0 | Process data. It includes: <br> (1) Download the dataset <br> (2) Calculate the CMVN of the train dataset <br> (3) Get the vocabulary file <br> (4) Get the manifest files of the train, development and test dataset<br> (5) Get the sentencepiece model |
| 1 | Train the model |
| 2 | Get the final model by averaging the top-k models, set k = 1 means to choose the best model |
| 3 | Test the final model performance |
| 4 | Get ctc alignment of test data using the final model |
| 5 | Infer the single audio file |
You can choose to run a range of stages by setting `stage` and `stop_stage `.
For example, if you want to execute the code in stage 2 and stage 3, you can run this script:
```bash
bash run.sh --stage 2 --stop_stage 3
```
Or you can set `stage` equal to `stop-stage` to only run one stage.
For example, if you only want to run `stage 0`, you can use the script below:
```bash
bash run.sh --stage 0 --stop_stage 0
```
The document below will describe the scripts in `run.sh` in detail.
## The Environment Variables
The path.sh contains the environment variables.
```bash
. ./path.sh
. ./cmd.sh
```
This script needs to be run first. And another script is also needed:
```bash
source ${MAIN_ROOT}/utils/parse_options.sh
```
It will support the way of using `--variable value` in the shell scripts.
## The Local Variables
Some local variables are set in `run.sh`.
`gpus` denotes the GPU number you want to use. If you set `gpus=`, it means you only use CPU.
`stage` denotes the number of stages you want to start from in the experiments.
`stop stage` denotes the number of the stage you want to end at in the experiments.
`conf_path` denotes the config path of the model.
`avg_num` denotes the number K of top-K models you want to average to get the final model.
`audio file` denotes the file path of the single file you want to infer in stage 5
`ckpt` denotes the checkpoint prefix of the model, e.g. "conformer"
You can set the local variables (except `ckpt`) when you use `run.sh`
For example, you can set the `gpus` and `avg_num` when you use the command line:
```bash
bash run.sh --gpus 0,1 --avg_num 10
```
## Stage 0: Data Processing
To use this example, you need to process data firstly and you can use stage 0 in `run.sh` to do this. The code is shown below:
```bash
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
```
Stage 0 is for processing the data.
If you only want to process the data. You can run
```bash
bash run.sh --stage 0 --stop_stage 0
```
You can also just run these scripts in your command line.
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
```
After processing the data, the `data` directory will look like this:
```bash
data/
|-- dev_set.meta
|-- lang_char
| `-- bpe_bpe_11297.model
| `-- bpe_bpe_11297.vocab
| `-- vocab.txt
|-- manifest.dev
|-- manifest.dev.raw
|-- manifest.test
|-- manifest.test.raw
|-- manifest.train
|-- manifest.train.raw
|-- mean_std.json
|-- test_set.meta
`-- train_set.meta
```
## Stage 1: Model Training
If you want to train the model. you can use stage 1 in `run.sh`. The code is shown below.
```bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
fi
```
If you want to train the model, you can use the script below to execute stage 0 and stage 1:
```bash
bash run.sh --stage 0 --stop_stage 1
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/conformer.yaml conformer
```
## Stage 2: Top-k Models Averaging
After training the model, we need to get the final model for testing and inference. In every epoch, the model checkpoint is saved, so we can choose the best model from them based on the validation loss or we can sort them and average the parameters of the top-k models to get the final model. We can use stage 2 to do this, and the code is shown below:
```bash
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
```
The `avg.sh` is in the `../../../utils/` which is define in the `path.sh`.
If you want to get the final model, you can use the script below to execute stage 0, stage 1, and stage 2:
```bash
bash run.sh --stage 0 --stop_stage 2
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/conformer.yaml conformer
avg.sh best exp/conformer/checkpoints 10
```
## Stage 3: Model Testing
The test stage is to evaluate the model performance. The code of test stage is shown below:
```bash
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
```
If you want to train a model and test it, you can use the script below to execute stage 0, stage 1, stage 2, and stage 3 :
```bash
bash run.sh --stage 0 --stop_stage 3
```
or you can run these scripts in the command line (only use CPU).
```bash
. ./path.sh
. ./cmd.sh
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/conformer.yaml conformer
avg.sh best exp/conformer/checkpoints 10
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_10
```
## Pretrained Model
You can get the pretrained transformer or conformer from [this](../../../docs/source/released_model.md).
using the `tar` scripts to unpack the model and then you can use the script to test the model.
For example:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/tal_cs/asr1/asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz
tar xzvf asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz
source path.sh
# If you have process the data and get the manifest file you can skip the following 2 steps
bash local/data.sh --stage -1 --stop_stage -1
bash local/data.sh --stage 2 --stop_stage 2
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/conformer.yaml exp/conformer/checkpoints/avg_10
```
The performance of the released models are shown in [here](./RESULTS.md).
## Stage 5: Single Audio File Inference
In some situations, you want to use the trained model to do the inference for the single audio file. You can use stage 5. The code is shown below
```bash
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi
```
you can train the model by yourself using ```bash run.sh --stage 0 --stop_stage 3```, or you can download the pretrained model through the script below:
```bash
wget https://paddlespeech.bj.bcebos.com/s2t/tal_cs/asr1/asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz
tar xzvf asr1_conformer_talcs_ckpt_1.4.0.model.tar.gz
```
You can download the audio demo:
```bash
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
```
You need to prepare an audio file or use the audio demo above, please confirm the sample rate of the audio is 16K. You can get the result of the audio demo by running the script below.
```bash
CUDA_VISIBLE_DEVICES= ./local/test_wav.sh conf/conformer.yaml exp/conformer/checkpoints/avg_10 data/demo_01_03.wav
```

@ -0,0 +1,12 @@
# TALCS
2023.1.6, commit id: fa724285f3b799b97b4348ad3b1084afc0764f9b
## Conformer
train: Epoch 100, 3 V100-32G, best avg: 10
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | attention | 9.85091028213501 | 0.102786 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | ctc_greedy_search | 9.85091028213501 | 0.103538 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | ctc_prefix_beam_search | 9.85091028213501 | 0.103317 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | attention_rescoring | 9.85091028213501 | 0.084374 |

@ -0,0 +1,91 @@
############################################
# Network Architecture #
############################################
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: conformer
encoder_conf:
output_size: 512 # dimension of attention
attention_heads: 8
linear_units: 2048 # the number of units of position-wise feed forward
num_blocks: 12 # the number of encoder blocks
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.0
input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8
normalize_before: True
cnn_module_kernel: 15
use_cnn_module: True
activation_type: 'swish'
pos_enc_layer_type: 'rel_pos'
selfattention_layer_type: 'rel_selfattn'
# decoder related
decoder: transformer
decoder_conf:
attention_heads: 8
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
init_type: 'kaiming_uniform' # !Warning: need to convergence
###########################################
# Data #
###########################################
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
###########################################
# Dataloader #
###########################################
vocab_filepath: data/lang_char/vocab.txt
spm_model_prefix: 'data/lang_char/bpe_bpe_11297'
unit_type: 'spm'
preprocess_config: conf/preprocess.yaml
feat_dim: 80
stride_ms: 20.0
window_ms: 30.0
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
batch_size: 5
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
minibatches: 0 # for debug
batch_count: auto
batch_bins: 0
batch_frames_in: 0
batch_frames_out: 0
batch_frames_inout: 0
num_workers: 2
subsampling_factor: 1
num_encs: 1
###########################################
# Training #
###########################################
n_epoch: 100
accum_grad: 4
global_grad_clip: 5.0
dist_sampler: False
optim: adam
optim_conf:
lr: 0.002
weight_decay: 1.0e-6
scheduler: warmuplr
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5

@ -0,0 +1,29 @@
process:
# extract kaldi fbank from PCM
- type: fbank_kaldi
fs: 16000
n_mels: 80
n_shift: 160
win_length: 400
dither: 1.0
- type: cmvn_json
cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument
- type: time_warp
max_time_warp: 5
inplace: true
mode: PIL
- type: freq_mask
F: 30
n_mask: 2
inplace: true
replace_with_zero: false
- type: time_mask
T: 40
n_mask: 2
inplace: true
replace_with_zero: false

@ -0,0 +1,12 @@
beam_size: 10
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: True # simulate streaming inference. Defaults to False.
decode_batch_size: 128
error_rate_type: cer

@ -0,0 +1,12 @@
beam_size: 10
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
#reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
# <0: for decoding, use full chunk.
# >0: for decoding, use fixed chunk size as set.
# 0: used for training, it's prohibited here.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
decode_batch_size: 1
error_rate_type: cer

@ -0,0 +1,88 @@
#!/bin/bash
stage=-1
stop_stage=100
dict_dir=data/lang_char
# bpemode (unigram or bpe)
nbpe=11297
bpemode=bpe
bpeprefix="${dict_dir}/bpe_${bpemode}_${nbpe}"
stride_ms=20
window_ms=30
sample_rate=16000
feat_dim=80
source ${MAIN_ROOT}/utils/parse_options.sh
mkdir -p data
mkdir -p ${dict_dir}
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}
#prepare data
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
if [ ! -d "${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus" ]; then
echo "${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus does not exist. Please donwload tal_cs data and unpack it from https://ai.100tal.com/dataset first."
echo "data md5 reference: 4c879b3c9c05365fc9dee1fc68713afe"
exit
fi
# create manifest json file from TALCS_corpus
python ${MAIN_ROOT}/dataset/tal_cs/tal_cs.py --target_dir ${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus/ --manifest_prefix data/
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# compute mean and stddev for normalizer
num_workers=$(nproc)
python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train.raw" \
--num_samples=-1 \
--spectrum_type="fbank" \
--feat_dim=${feat_dim} \
--delta_delta=false \
--sample_rate=${sample_rate} \
--stride_ms=${stride_ms} \
--window_ms=${window_ms} \
--use_dB_normalization=False \
--num_workers=${num_workers} \
--output_path="data/mean_std.json"
echo "compute mean and stddev done."
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
#use train_set build dict
python3 ${MAIN_ROOT}/utils/build_vocab.py \
--unit_type 'spm' \
--count_threshold=0 \
--vocab_path="${dict_dir}/vocab.txt" \
--manifest_paths="data/manifest.train.raw" \
--spm_mode=${bpemode} \
--spm_vocab_size=${nbpe} \
--spm_model_prefix=${bpeprefix} \
--spm_character_coverage=1
echo "build dict done."
fi
#use new dict format data
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# format manifest with tokenids, vocab size
for sub in train dev test ; do
{
python3 ${MAIN_ROOT}/utils/format_data.py \
--cmvn_path "data/mean_std.json" \
--unit_type "spm" \
--spm_model_prefix ${bpeprefix} \
--vocab_path="${dict_dir}/vocab.txt" \
--manifest_path="data/manifest.${sub}.raw" \
--output_path="data/manifest.${sub}"
if [ $? -ne 0 ]; then
echo "Formt mnaifest failed. Terminated."
exit 1
fi
}&
done
wait
echo "format data done."
fi

@ -0,0 +1,72 @@
#!/bin/bash
if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
batch_size=1
else
batch_size=64
fi
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0

@ -0,0 +1,58 @@
#!/bin/bash
if [ $# != 4 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
if [ $? -ne 0 ]; then
exit 1
fi
if [ ! -f ${audio_file} ]; then
echo "Plase input the right audio_file path"
exit 1
fi
chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
chunk_mode=true
fi
# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
# exit 1
#fi
for type in attention_rescoring; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
python3 -u ${BIN_DIR}/test_wav.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${output_dir}/${type}.rsl \
--checkpoint_path ${ckpt_prefix} \
--opts decode.decoding_method ${type} \
--opts decode.decode_batch_size ${batch_size} \
--audio_file ${audio_file}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
exit 0

@ -0,0 +1,72 @@
#!/bin/bash
profiler_options=
benchmark_batch_size=0
benchmark_max_step=0
# seed may break model convergence
seed=0
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
if [ ${seed} != 0 ]; then
export FLAGS_cudnn_deterministic=True
echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi
if [ $# -lt 2 ] && [ $# -gt 3 ];then
echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
exit -1
fi
config_path=$1
ckpt_name=$2
ips=$3
if [ ! $ips ];then
ips_config=
else
ips_config="--ips="${ips}
fi
echo ${ips_config}
mkdir -p exp
# default memeory allocator strategy may case gpu training hang
# for no OOM raised when memory exhaused
export FLAGS_allocator_strategy=naive_best_fit
if [ ${ngpu} == 0 ]; then
python3 -u ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
else
python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
--ngpu ${ngpu} \
--seed ${seed} \
--config ${config_path} \
--output exp/${ckpt_name} \
--profiler-options "${profiler_options}" \
--benchmark-batch-size ${benchmark_batch_size} \
--benchmark-max-step ${benchmark_max_step}
fi
if [ ${seed} != 0 ]; then
unset FLAGS_cudnn_deterministic
fi
if [ $? -ne 0 ]; then
echo "Failed in training!"
exit 1
fi
exit 0

@ -0,0 +1,15 @@
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
# model exp
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin

@ -0,0 +1,51 @@
#!/bin/bash
source path.sh || exit 1;
set -e
gpus=0,1,2,3
stage=0
stop_stage=50
conf_path=conf/conformer.yaml
ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
decode_conf_path=conf/tuning/decode.yaml
average_checkpoint=true
avg_num=10
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"
audio_file="data/demo_01_03.wav"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
bash ./local/data.sh || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `exp` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# avg n best model
avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# test a single .wav file
CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi
# Not supported at now!!!
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

@ -0,0 +1 @@
../../../utils
Loading…
Cancel
Save