[ASR] add code-switch asr tal_cs recipe (#2796)
* add tal_cs asr recipe. * add readme and result, and fix some bug. * add commit id and date.pull/2802/head
parent
25dcad3de7
commit
e793d267d9
@ -0,0 +1,13 @@
|
||||
# [TAL_CSASR](https://ai.100tal.com/dataset/)
|
||||
|
||||
This dataset consists of TAL English class audio with mixed Chinese and English speech. Each audio file has only one speaker, and the dataset contains more than 100 speakers in total (63.36 GB of files). The data includes both intra-sentence and inter-sentence code-switching; the ratio of Chinese characters to English words is 13:1.
|
||||
|
||||
- Total data: 587H (train_set: 555.9H, dev_set: 8H, test_set: 23.6H)
|
||||
- Sample rate: 16000
|
||||
- Sample bit: 16
|
||||
- Recording device: microphone
|
||||
- Speaker number: 200+
|
||||
- Recording time: 2019
|
||||
- Data format: audio: .wav; text: .txt
|
||||
- Audio duration: 1-60s
|
||||
- Data type: audio of English teachers' teaching
|
@ -0,0 +1,116 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Prepare TALCS ASR datasets.

Create manifest files for the TAL code-switching (Chinese/English) corpus.

A manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import io
import json
import os

import soundfile

# Command-line interface: where the unpacked corpus lives and where the
# generated manifest files should be written.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()

# Corpus split directories under the target directory.
# NOTE(review): these are computed from the raw --target_dir argument,
# before main() expands a leading '~' — confirm callers pass absolute paths.
TRAIN_SET = os.path.join(args.target_dir, "train_set")
DEV_SET = os.path.join(args.target_dir, "dev_set")
TEST_SET = os.path.join(args.target_dir, "test_set")

# Raw (un-tokenized) manifest output paths, one per split.
manifest_train_path = os.path.join(args.manifest_prefix, "manifest.train.raw")
manifest_dev_path = os.path.join(args.manifest_prefix, "manifest.dev.raw")
manifest_test_path = os.path.join(args.manifest_prefix, "manifest.test.raw")
||||
def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set.

    Each line of the manifest contains the meta data (i.e. audio filepath,
    transcription text, audio duration) of one audio file in ``data_dir``.
    A human-readable ``<split>.meta`` summary is written next to the manifest.

    Args:
        data_dir: split directory containing a ``wav/`` folder and a
            ``label.txt`` file (``<utt-id> <token> <token> ...`` per line).
        manifest_path: output path of the json-lines manifest.
    """
    print("Creating manifest %s ..." % manifest_path)
    wav_dir = os.path.join(data_dir, 'wav')
    text_filepath = os.path.join(data_dir, 'label.txt')

    # Index every wav file once.  The original code re-read label.txt for
    # every directory os.walk visited, which duplicates manifest entries
    # when wav/ contains nested sub-folders and crashes when an utterance's
    # wav file is not present in the folder currently being visited.
    wav_paths = {}
    for subfolder, _, filelist in sorted(os.walk(wav_dir)):
        for fname in filelist:
            if fname.endswith('.wav'):
                utt_id = os.path.splitext(fname)[0]
                wav_paths[utt_id] = os.path.abspath(
                    os.path.join(subfolder, fname))

    json_lines = []
    total_sec = 0.0
    total_char = 0.0  # counts transcript tokens (space-separated units)
    total_num = 0
    for line in io.open(text_filepath, encoding="utf8"):
        segments = line.strip().split()
        if not segments:
            continue  # skip blank lines defensively
        utt = segments[0]
        tokens = segments[1:]
        text = ' '.join(tokens).lower()

        # Prefer the indexed path; fall back to the flat layout
        # (wav/<utt>.wav) the original recipe assumed.
        audio_filepath = wav_paths.get(utt)
        if audio_filepath is None:
            audio_filepath = os.path.abspath(
                os.path.join(wav_dir, utt + '.wav'))
        audio_data, samplerate = soundfile.read(audio_filepath)
        duration = float(len(audio_data)) / samplerate

        # Speaker id convention: first two '-'-separated fields of the
        # utterance id — TODO confirm against the corpus naming scheme.
        utt2spk = '-'.join(utt.split('-')[:2])

        json_lines.append(
            json.dumps({
                'utt': utt,
                'utt2spk': utt2spk,
                'feat': audio_filepath,
                'feat_shape': (duration, ),  # second
                'text': text,
            }))

        total_sec += duration
        total_char += len(tokens)
        total_num += 1

    with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')

    # Write a small summary file next to the manifest.
    subset = os.path.splitext(manifest_path)[1][1:]
    manifest_dir = os.path.dirname(manifest_path)
    data_dir_name = os.path.split(data_dir)[-1]
    meta_path = os.path.join(manifest_dir, data_dir_name) + '.meta'
    with open(meta_path, 'w') as f:
        print(f"{subset}:", file=f)
        print(f"{total_num} utts", file=f)
        print(f"{total_sec / (60*60)} h", file=f)
        print(f"{total_char} char", file=f)
        print(f"{total_char / total_sec} char/sec", file=f)
        print(f"{total_sec / total_num} sec/utt", file=f)
|
||||
|
||||
|
||||
def main():
    """Build train/dev/test manifests for the TALCS corpus."""
    # Expand a leading '~' before building the split paths.  The
    # module-level TRAIN_SET/DEV_SET/TEST_SET constants were computed from
    # the raw argument at import time, so expanding args.target_dir here
    # had no effect on them — rebuild the paths after expansion instead.
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    create_manifest(
        os.path.join(args.target_dir, "train_set"), manifest_train_path)
    create_manifest(
        os.path.join(args.target_dir, "dev_set"), manifest_dev_path)
    create_manifest(
        os.path.join(args.target_dir, "test_set"), manifest_test_path)
    print("Data download and manifest prepare done!")


if __name__ == '__main__':
    main()
|
@ -0,0 +1,12 @@
|
||||
# TALCS
|
||||
2023.1.6, commit id: fa724285f3b799b97b4348ad3b1084afc0764f9b
|
||||
|
||||
## Conformer
|
||||
train: Epoch 100, 3 V100-32G, best avg: 10
|
||||
|
||||
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | attention | 9.85091028213501 | 0.102786 |
|
||||
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | ctc_greedy_search | 9.85091028213501 | 0.103538 |
|
||||
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | ctc_prefix_beam_search | 9.85091028213501 | 0.103317 |
|
||||
| conformer | 47.63 M | conf/conformer.yaml | spec_aug | test-set | attention_rescoring | 9.85091028213501 | 0.084374 |
|
@ -0,0 +1,91 @@
|
||||
############################################
|
||||
# Network Architecture #
|
||||
############################################
|
||||
cmvn_file:
|
||||
cmvn_file_type: "json"
|
||||
# encoder related
|
||||
encoder: conformer
|
||||
encoder_conf:
|
||||
output_size: 512 # dimension of attention
|
||||
attention_heads: 8
|
||||
linear_units: 2048 # the number of units of position-wise feed forward
|
||||
num_blocks: 12 # the number of encoder blocks
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
|
||||
normalize_before: True
|
||||
cnn_module_kernel: 15
|
||||
use_cnn_module: True
|
||||
activation_type: 'swish'
|
||||
pos_enc_layer_type: 'rel_pos'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
|
||||
# decoder related
|
||||
decoder: transformer
|
||||
decoder_conf:
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.0
|
||||
src_attention_dropout_rate: 0.0
|
||||
|
||||
# hybrid CTC/attention
|
||||
model_conf:
|
||||
ctc_weight: 0.3
|
||||
lsm_weight: 0.1 # label smoothing option
|
||||
length_normalized_loss: false
|
||||
init_type: 'kaiming_uniform' # !Warning: needed for convergence
|
||||
|
||||
###########################################
|
||||
# Data #
|
||||
###########################################
|
||||
train_manifest: data/manifest.train
|
||||
dev_manifest: data/manifest.dev
|
||||
test_manifest: data/manifest.test
|
||||
|
||||
###########################################
|
||||
# Dataloader #
|
||||
###########################################
|
||||
vocab_filepath: data/lang_char/vocab.txt
|
||||
spm_model_prefix: 'data/lang_char/bpe_bpe_11297'
|
||||
unit_type: 'spm'
|
||||
preprocess_config: conf/preprocess.yaml
|
||||
feat_dim: 80
|
||||
stride_ms: 20.0
|
||||
window_ms: 30.0
|
||||
sortagrad: 0 # Feed samples from shortest to longest ; -1: enabled for all epochs, 0: disabled, other: enabled for 'other' epochs
|
||||
batch_size: 5
|
||||
maxlen_in: 512 # if input length > maxlen-in, batchsize is automatically reduced
|
||||
maxlen_out: 150 # if output length > maxlen-out, batchsize is automatically reduced
|
||||
minibatches: 0 # for debug
|
||||
batch_count: auto
|
||||
batch_bins: 0
|
||||
batch_frames_in: 0
|
||||
batch_frames_out: 0
|
||||
batch_frames_inout: 0
|
||||
num_workers: 2
|
||||
subsampling_factor: 1
|
||||
num_encs: 1
|
||||
|
||||
###########################################
|
||||
# Training #
|
||||
###########################################
|
||||
n_epoch: 100
|
||||
accum_grad: 4
|
||||
global_grad_clip: 5.0
|
||||
dist_sampler: False
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.002
|
||||
weight_decay: 1.0e-6
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
lr_decay: 1.0
|
||||
log_interval: 100
|
||||
checkpoint:
|
||||
kbest_n: 50
|
||||
latest_n: 5
|
@ -0,0 +1,29 @@
|
||||
process:
|
||||
# extract kaldi fbank from PCM
|
||||
- type: fbank_kaldi
|
||||
fs: 16000
|
||||
n_mels: 80
|
||||
n_shift: 160
|
||||
win_length: 400
|
||||
dither: 1.0
|
||||
- type: cmvn_json
|
||||
cmvn_path: data/mean_std.json
|
||||
# these three processes are collectively known as SpecAugment
|
||||
- type: time_warp
|
||||
max_time_warp: 5
|
||||
inplace: true
|
||||
mode: PIL
|
||||
- type: freq_mask
|
||||
F: 30
|
||||
n_mask: 2
|
||||
inplace: true
|
||||
replace_with_zero: false
|
||||
- type: time_mask
|
||||
T: 40
|
||||
n_mask: 2
|
||||
inplace: true
|
||||
replace_with_zero: false
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,12 @@
|
||||
beam_size: 10
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: 16 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: True # simulate streaming inference. Defaults to False.
|
||||
decode_batch_size: 128
|
||||
error_rate_type: cer
|
@ -0,0 +1,12 @@
|
||||
beam_size: 10
|
||||
decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring'
|
||||
ctc_weight: 0.5 # ctc weight for attention rescoring decode mode.
|
||||
#reverse_weight: 0.3 # reverse weight for attention rescoring decode mode.
|
||||
decoding_chunk_size: -1 # decoding chunk size. Defaults to -1.
|
||||
# <0: for decoding, use full chunk.
|
||||
# >0: for decoding, use fixed chunk size as set.
|
||||
# 0: used for training, it's prohibited here.
|
||||
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
|
||||
simulate_streaming: False # simulate streaming inference. Defaults to False.
|
||||
decode_batch_size: 1
|
||||
error_rate_type: cer
|
@ -0,0 +1,88 @@
|
||||
#!/bin/bash
# Data preparation for the TALCS recipe: manifest creation, CMVN stats,
# vocabulary/BPE model building and manifest formatting.

stage=-1
stop_stage=100
dict_dir=data/lang_char

# bpemode (unigram or bpe)
nbpe=11297
bpemode=bpe
bpeprefix="${dict_dir}/bpe_${bpemode}_${nbpe}"

# feature extraction parameters (must match conf/conformer.yaml)
stride_ms=20
window_ms=30
sample_rate=16000
feat_dim=80

source ${MAIN_ROOT}/utils/parse_options.sh

mkdir -p data
mkdir -p ${dict_dir}
TARGET_DIR=${MAIN_ROOT}/dataset
mkdir -p ${TARGET_DIR}

# prepare data
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    if [ ! -d "${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus" ]; then
        echo "${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus does not exist. Please download tal_cs data and unpack it from https://ai.100tal.com/dataset first."
        echo "data md5 reference: 4c879b3c9c05365fc9dee1fc68713afe"
        # exit non-zero so callers notice the missing corpus
        # (a bare `exit` would return status 0 here)
        exit 1
    fi
    # create manifest json file from TALCS_corpus
    python ${MAIN_ROOT}/dataset/tal_cs/tal_cs.py --target_dir ${MAIN_ROOT}/dataset/tal_cs/TALCS_corpus/ --manifest_prefix data/
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # compute mean and stddev for normalizer
    num_workers=$(nproc)
    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
        --manifest_path="data/manifest.train.raw" \
        --num_samples=-1 \
        --spectrum_type="fbank" \
        --feat_dim=${feat_dim} \
        --delta_delta=false \
        --sample_rate=${sample_rate} \
        --stride_ms=${stride_ms} \
        --window_ms=${window_ms} \
        --use_dB_normalization=False \
        --num_workers=${num_workers} \
        --output_path="data/mean_std.json"
    echo "compute mean and stddev done."
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # use train_set to build the dict
    python3 ${MAIN_ROOT}/utils/build_vocab.py \
        --unit_type 'spm' \
        --count_threshold=0 \
        --vocab_path="${dict_dir}/vocab.txt" \
        --manifest_paths="data/manifest.train.raw" \
        --spm_mode=${bpemode} \
        --spm_vocab_size=${nbpe} \
        --spm_model_prefix=${bpeprefix} \
        --spm_character_coverage=1
    echo "build dict done."
fi

# use the new dict to format the data
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size (splits run in parallel)
    for sub in train dev test ; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
            --cmvn_path "data/mean_std.json" \
            --unit_type "spm" \
            --spm_model_prefix ${bpeprefix} \
            --vocab_path="${dict_dir}/vocab.txt" \
            --manifest_path="data/manifest.${sub}.raw" \
            --output_path="data/manifest.${sub}"

        if [ $? -ne 0 ]; then
            echo "Format manifest failed. Terminated."
            exit 1
        fi
    }&
    done
    wait
    echo "format data done."
fi
|
@ -0,0 +1,72 @@
|
||||
#!/bin/bash
# Evaluate an averaged checkpoint with all four decoding methods.

if [ $# != 3 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3

chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

# Run one decoding pass.
# Args: $1 = decoding method, $2 = decode batch size.
decode_one() {
    local type=$1
    local batch_size=$2
    echo "decoding ${type}"
    local output_dir=${ckpt_prefix}
    mkdir -p ${output_dir}
    python3 -u ${BIN_DIR}/test.py \
        --ngpu ${ngpu} \
        --config ${config_path} \
        --decode_cfg ${decode_config_path} \
        --result_file ${output_dir}/${type}.rsl \
        --checkpoint_path ${ckpt_prefix} \
        --opts decode.decoding_method ${type} \
        --opts decode.decode_batch_size ${batch_size}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
}

# Batched methods: stream (chunk) decoding only supports batchsize=1.
if [ ${chunk_mode} == true ];then
    batch_size=1
else
    batch_size=64
fi
for type in attention ctc_greedy_search; do
    decode_one ${type} ${batch_size}
done

# Beam-search / rescoring methods decode one utterance at a time.
for type in ctc_prefix_beam_search attention_rescoring; do
    decode_one ${type} 1
done

exit 0
|
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
# Decode a single .wav file with an averaged checkpoint.

if [ $# != 4 ];then
    echo "usage: ${0} config_path decode_config_path ckpt_path_prefix audio_file"
    exit -1
fi

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

config_path=$1
decode_config_path=$2
ckpt_prefix=$3
audio_file=$4

# fetch the demo wav if it is not already present (-nc: no clobber)
mkdir -p data
wget -nc https://paddlespeech.bj.bcebos.com/datasets/single_wav/zh/demo_01_03.wav -P data/
if [ $? -ne 0 ]; then
    exit 1
fi

if [ ! -f ${audio_file} ]; then
    echo "Please input the right audio_file path"
    exit 1
fi

chunk_mode=false
if [[ ${config_path} =~ ^.*chunk_.*yaml$ ]];then
    chunk_mode=true
fi

# download language model
#bash local/download_lm_ch.sh
#if [ $? -ne 0 ]; then
#    exit 1
#fi

for type in attention_rescoring; do
    echo "decoding ${type}"
    batch_size=1
    output_dir=${ckpt_prefix}
    mkdir -p ${output_dir}
    python3 -u ${BIN_DIR}/test_wav.py \
        --ngpu ${ngpu} \
        --config ${config_path} \
        --decode_cfg ${decode_config_path} \
        --result_file ${output_dir}/${type}.rsl \
        --checkpoint_path ${ckpt_prefix} \
        --opts decode.decoding_method ${type} \
        --opts decode.decode_batch_size ${batch_size} \
        --audio_file ${audio_file}

    if [ $? -ne 0 ]; then
        echo "Failed in evaluation!"
        exit 1
    fi
done
exit 0
|
@ -0,0 +1,72 @@
|
||||
#!/bin/bash
# Train the model on one or more GPUs (distributed launch when ngpu > 0).

profiler_options=
benchmark_batch_size=0
benchmark_max_step=0

# seed may break model convergence
seed=0

source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."

if [ ${seed} != 0 ]; then
    export FLAGS_cudnn_deterministic=True
    echo "using seed $seed & FLAGS_cudnn_deterministic=True ..."
fi

# The original test used `&&`, which can never be true
# ($# cannot be both < 2 and > 3); the intended check is OR.
if [ $# -lt 2 ] || [ $# -gt 3 ];then
    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name ips(optional)"
    exit -1
fi

config_path=$1
ckpt_name=$2
ips=$3

if [ ! $ips ];then
    ips_config=
else
    ips_config="--ips="${ips}
fi
echo ${ips_config}

mkdir -p exp

# default memory allocator strategy may cause gpu training to hang
# so that no OOM is raised when memory is exhausted
export FLAGS_allocator_strategy=naive_best_fit

if [ ${ngpu} == 0 ]; then
    python3 -u ${BIN_DIR}/train.py \
        --ngpu ${ngpu} \
        --seed ${seed} \
        --config ${config_path} \
        --output exp/${ckpt_name} \
        --profiler-options "${profiler_options}" \
        --benchmark-batch-size ${benchmark_batch_size} \
        --benchmark-max-step ${benchmark_max_step}
else
    python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
        --ngpu ${ngpu} \
        --seed ${seed} \
        --config ${config_path} \
        --output exp/${ckpt_name} \
        --profiler-options "${profiler_options}" \
        --benchmark-batch-size ${benchmark_batch_size} \
        --benchmark-max-step ${benchmark_max_step}
fi
# Capture the training exit status immediately: the original checked $?
# after the seed-cleanup `if`, so it tested that `if`, not the training.
status=$?

if [ ${seed} != 0 ]; then
    unset FLAGS_cudnn_deterministic
fi

if [ ${status} -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi

exit 0
|
@ -0,0 +1,15 @@
|
||||
# Environment for the TALCS recipe; sourced by run.sh before any stage.
# Root of the repository checkout (three levels up from this recipe dir).
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/

# model exp
# BIN_DIR points at the u2 (conformer/transformer) train/test entry points.
MODEL=u2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
|
@ -0,0 +1,51 @@
|
||||
#!/bin/bash
# Top-level driver for the TALCS recipe.
# Stages: 0 data prep, 1 train, 2 average checkpoints, 3 test,
#         4 single-wav test, 51 export (not supported yet).
source path.sh || exit 1;
set -e

gpus=0,1,2,3
stage=0
stop_stage=50
conf_path=conf/conformer.yaml
ips= #xxx.xxx.xxx.xxx,xxx.xxx.xxx.xxx
decode_conf_path=conf/tuning/decode.yaml
average_checkpoint=true
avg_num=10

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

# checkpoint directory name is derived from the config file name
avg_ckpt=avg_${avg_num}
ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
echo "checkpoint name ${ckpt}"

audio_file="data/demo_01_03.wav"

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    bash ./local/data.sh || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` under `exp` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt} ${ips}
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # avg n best model
    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # test ckpt avg_n
    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # test a single .wav file
    CUDA_VISIBLE_DEVICES=0 ./local/test_wav.sh ${conf_path} ${decode_conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi

# Not supported at now!!!
if [ ${stage} -le 51 ] && [ ${stop_stage} -ge 51 ]; then
    # export ckpt avg_n
    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
|
@ -0,0 +1 @@
|
||||
../../../utils
|
Loading…
Reference in new issue