From b9790d03f2564c9d6a5361d34a4e81c26d1db4bc Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 18 Nov 2021 02:51:20 +0000 Subject: [PATCH] add wenetspeech egs --- examples/aishell/s1/run.sh | 2 +- examples/wenetspeech/asr1/conf/conformer.yaml | 113 +++++++++++++++ .../wenetspeech/asr1/conf/preprocess.yaml | 29 ++++ examples/wenetspeech/asr1/local/data.sh | 129 +++++++++++++++++ .../wenetspeech/asr1/local/extract_meta.py | 102 +++++++++++++ .../wenetspeech/asr1/local/process_opus.py | 89 ++++++++++++ examples/wenetspeech/asr1/local/test.sh | 1 + .../asr1/local/wenetspeech_data_prep.sh | 135 ++++++++++++++++++ examples/wenetspeech/asr1/path.sh | 15 ++ examples/wenetspeech/asr1/run.sh | 55 +++++++ 10 files changed, 669 insertions(+), 1 deletion(-) create mode 100644 examples/wenetspeech/asr1/conf/conformer.yaml create mode 100644 examples/wenetspeech/asr1/conf/preprocess.yaml create mode 100644 examples/wenetspeech/asr1/local/data.sh create mode 100644 examples/wenetspeech/asr1/local/extract_meta.py create mode 100644 examples/wenetspeech/asr1/local/process_opus.py create mode 100644 examples/wenetspeech/asr1/local/test.sh create mode 100644 examples/wenetspeech/asr1/local/wenetspeech_data_prep.sh create mode 100644 examples/wenetspeech/asr1/path.sh create mode 100644 examples/wenetspeech/asr1/run.sh diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 126c8e4e..94c2c4df 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -53,5 +53,5 @@ fi if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then # test a single .wav file - CUDA_VISIBLE_DEVICES=3 ./local/test_hub.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test_hub.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1 fi diff --git a/examples/wenetspeech/asr1/conf/conformer.yaml b/examples/wenetspeech/asr1/conf/conformer.yaml new file mode 100644 index 00000000..0340dc85 --- /dev/null 
+++ b/examples/wenetspeech/asr1/conf/conformer.yaml @@ -0,0 +1,113 @@ +# network architecture +model: + # encoder related + encoder: conformer + encoder_conf: + output_size: 512 # dimension of attention + attention_heads: 8 + linear_units: 2048 # the number of units of position-wise feed forward + num_blocks: 12 # the number of encoder blocks + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.0 + input_layer: conv2d # encoder input type, you can chose conv2d, conv2d6 and conv2d8 + normalize_before: True + use_cnn_module: True + cnn_module_kernel: 15 + cnn_module_norm: layer_norm + activation_type: swish + pos_enc_layer_type: rel_pos + selfattention_layer_type: rel_selfattn + + # decoder related + decoder: transformer + decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + + # hybrid CTC/attention + model_conf: + ctc_weight: 0.3 + ctc_dropoutrate: 0.0 + ctc_grad_norm_type: null + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# https://yaml.org/type/float.html +data: + train_manifest: data/manifest.train + dev_manifest: data/manifest.dev + test_manifest: data/manifest.test + min_input_len: 0.1 # second + max_input_len: 12.0 # second + min_output_len: 1.0 + max_output_len: 400.0 + min_output_input_ratio: 0.05 + max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'char' + spm_model_prefix: '' + augmentation_config: conf/preprocess.yaml + batch_size: 64 + raw_wav: True # use raw_wav or kaldi feature + spectrum_type: fbank #linear, mfcc, fbank + feat_dim: 80 + delta_delta: False + dither: 1.0 + target_sample_rate: 16000 + max_freq: None + n_fft: None + stride_ms: 10.0 + window_ms: 25.0 + use_dB_normalization: True + target_dB: -20 + random_seed: 0 + keep_transcription_text: False + sortagrad: True + shuffle_method: batch_shuffle + 
num_workers: 2 + + +training: + n_epoch: 240 + accum_grad: 16 + global_grad_clip: 5.0 + log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 + optim: adam + optim_conf: + lr: 0.001 + weight_decay: 1e-6 + scheduler: warmuplr # pytorch v1.1.0+ required + scheduler_conf: + warmup_steps: 5000 + lr_decay: 1.0 + + +decoding: + batch_size: 128 + error_rate_type: cer + decoding_method: attention # 'attention', 'ctc_greedy_search', 'ctc_prefix_beam_search', 'attention_rescoring' + lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm + alpha: 2.5 + beta: 0.3 + beam_size: 10 + cutoff_prob: 1.0 + cutoff_top_n: 0 + num_proc_bsearch: 8 + ctc_weight: 0.5 # ctc weight for attention rescoring decode mode. + decoding_chunk_size: -1 # decoding chunk size. Defaults to -1. + # <0: for decoding, use full chunk. + # >0: for decoding, use fixed chunk size as set. + # 0: used for training, it's prohibited here. + num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. + simulate_streaming: False # simulate streaming inference. Defaults to False. \ No newline at end of file diff --git a/examples/wenetspeech/asr1/conf/preprocess.yaml b/examples/wenetspeech/asr1/conf/preprocess.yaml new file mode 100644 index 00000000..dd4cfd27 --- /dev/null +++ b/examples/wenetspeech/asr1/conf/preprocess.yaml @@ -0,0 +1,29 @@ +process: + # extract kaldi fbank from PCM + - type: fbank_kaldi + fs: 16000 + n_mels: 80 + n_shift: 160 + win_length: 400 + dither: true + - type: cmvn_json + cmvn_path: data/mean_std.json + # these three processes are a.k.a. 
#!/bin/bash

# Copyright 2021  Mobvoi Inc(Author: Di Wu, Binbin Zhang)
#                 NPU, ASLP Group (Author: Qijie Shao)

# Data preparation for the WenetSpeech recipe.
#
# Stages handled here:
#   -2  print download instructions and exit
#    0  extract the corpus metadata and split it into subsets
#   -1  generate the raw manifests
#    0  compute the mean/stddev used for feature normalisation
#    1  build the character vocabulary
#    2  format the manifests (token ids, vocab size)
#
# MAIN_ROOT must be set by the caller (run.sh sources path.sh, which exports it).

stage=-1
stop_stage=100

# Use your own data path. You need to download the WenetSpeech dataset by yourself.
wenetspeech_data_dir=./wenetspeech
# Make sure you have 1.2T for ${shards_dir}
shards_dir=./wenetspeech_shards

# WenetSpeech training set label; lower-cased to form the data directory name.
set=L
train_set=train_`echo $set | tr 'A-Z' 'a-z'`
dev_set=dev
test_sets="test_net test_meeting"

cmvn=true
cmvn_sampling_divisor=20 # 20 means 5% of the training data to estimate cmvn


. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
set -u
set -o pipefail


mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}

if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
    # download data
    echo "Please follow https://github.com/wenet-e2e/WenetSpeech to download the data."
    exit 0;
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "Data preparation"
    local/wenetspeech_data_prep.sh \
        --train-subset $set \
        $wenetspeech_data_dir \
        data || exit 1;
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    # generate manifests
    # NOTE(review): this still calls the Aishell manifest script; presumably a
    # placeholder until a WenetSpeech manifest generator exists - confirm.
    python3 ${TARGET_DIR}/aishell/aishell.py \
    --manifest_prefix="data/manifest" \
    --target_dir="${TARGET_DIR}/aishell"

    if [ $? -ne 0 ]; then
        echo "Prepare Aishell failed. Terminated."
        exit 1
    fi

    for dataset in train dev test; do
        mv data/manifest.${dataset} data/manifest.${dataset}.raw
    done
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # compute mean and stddev for normalizer
    if $cmvn; then
        # NOTE(review): wav.scp.sampled is written here but nothing visible in
        # this script reads it; compute_mean_std.py below is given the full raw
        # manifest with --num_samples=-1 - confirm the intended sampling.
        full_size=`cat data/${train_set}/wav.scp | wc -l`
        sampling_size=$((full_size / cmvn_sampling_divisor))
        shuf -n $sampling_size data/$train_set/wav.scp \
            > data/$train_set/wav.scp.sampled
        num_workers=$(nproc)

        python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
        --manifest_path="data/manifest.train.raw" \
        --spectrum_type="fbank" \
        --feat_dim=80 \
        --delta_delta=false \
        --stride_ms=10 \
        --window_ms=25 \
        --sample_rate=16000 \
        --use_dB_normalization=False \
        --num_samples=-1 \
        --num_workers=${num_workers} \
        --output_path="data/mean_std.json"

        if [ $? -ne 0 ]; then
            echo "Compute mean and stddev failed. Terminated."
            exit 1
        fi
    fi
fi

dict=data/dict/lang_char.txt
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # build vocabulary
    python3 ${MAIN_ROOT}/utils/build_vocab.py \
    --unit_type="char" \
    --count_threshold=0 \
    --vocab_path="data/vocab.txt" \
    --manifest_paths "data/manifest.train.raw"

    if [ $? -ne 0 ]; then
        echo "Build vocabulary failed. Terminated."
        exit 1
    fi
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # format manifest with tokenids, vocab size
    # The three datasets are formatted in parallel. Each job's pid is recorded
    # because a bare `wait` always returns 0, and the `exit 1` inside a
    # background job only ends that job, not this script.
    pids=()
    for dataset in train dev test; do
    {
        python3 ${MAIN_ROOT}/utils/format_data.py \
        --cmvn_path "data/mean_std.json" \
        --unit_type "char" \
        --vocab_path="data/vocab.txt" \
        --manifest_path="data/manifest.${dataset}.raw" \
        --output_path="data/manifest.${dataset}"

        if [ $? -ne 0 ]; then
            echo "Format manifest failed. Terminated."
            exit 1
        fi
    } &
    pids+=($!)
    done

    # `wait <pid>` returns that job's exit status, so a failure is propagated.
    for pid in "${pids[@]}"; do
        wait ${pid} || exit 1
    done
fi

echo "WenetSpeech data preparation done."
exit 0
def get_args():
    """Parse the command-line arguments of the metadata extraction script.

    Returns:
        argparse.Namespace: holds the two positional arguments ``input_json``
        and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="""
      This script is used to process raw json dataset of WenetSpeech,
      where the long wav is split into segments and
      data of wenet format is generated.
      """)
    parser.add_argument('input_json', help="""Input json file of WenetSpeech""")
    parser.add_argument('output_dir', help="""Output dir for prepared data""")

    args = parser.parse_args()
    return args


def meta_analysis(input_json, output_dir):
    """Turn the WenetSpeech json metadata into tab-separated text files.

    For every entry of ``json_data['audios']`` one line is written to
    ``wav.scp`` (audio id, resolved audio path) and ``reco2dur`` (audio id,
    duration). For every segment of that entry one line is written to
    ``text``, ``segments``, ``utt2dur`` and ``utt2subsets``. Entries or
    segments with missing fields, and entries whose audio file does not
    exist, are reported on stdout and skipped.

    Args:
        input_json: Path of the json file. Audio paths inside it are resolved
            relative to the directory containing this file.
        output_dir: Directory receiving the generated files; created when
            missing.

    Raises:
        SystemExit: If the json file cannot be opened or parsed.
    """
    input_dir = os.path.dirname(input_json)
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Explicit UTF-8: the transcripts are not ASCII and the default
        # encoding of open() depends on the locale.
        with open(input_json, 'r', encoding='utf-8') as injson:
            json_data = json.load(injson)
    except Exception:
        sys.exit(f'Failed to load input json file: {input_json}')

    if json_data['audios'] is None:
        return

    with open(f'{output_dir}/text', 'w', encoding='utf-8') as utt2text, \
         open(f'{output_dir}/segments', 'w', encoding='utf-8') as segments, \
         open(f'{output_dir}/utt2dur', 'w', encoding='utf-8') as utt2dur, \
         open(f'{output_dir}/wav.scp', 'w', encoding='utf-8') as wavscp, \
         open(f'{output_dir}/utt2subsets', 'w',
              encoding='utf-8') as utt2subsets, \
         open(f'{output_dir}/reco2dur', 'w', encoding='utf-8') as reco2dur:
        for long_audio in json_data['audios']:
            # Reset per entry so the warning below never hits an unbound
            # name or reports values left over from the previous entry.
            aid = None
            long_audio_path = None
            try:
                aid = long_audio['aid']
                long_audio_path = os.path.realpath(
                    os.path.join(input_dir, long_audio['path']))
                segments_lists = long_audio['segments']
                duration = long_audio['duration']
            except Exception:
                print(f'Warning: {aid} something is wrong, maybe the '
                      f'error path: {long_audio_path}, skipped')
                continue

            # Explicit check instead of `assert`, which is removed under -O.
            if not os.path.exists(long_audio_path):
                print(f'Warning: {aid} audio file not found: '
                      f'{long_audio_path}, skipped')
                continue

            wavscp.write(f'{aid}\t{long_audio_path}\n')
            reco2dur.write(f'{aid}\t{duration}\n')
            for segment_file in segments_lists:
                try:
                    sid = segment_file['sid']
                    start_time = segment_file['begin_time']
                    end_time = segment_file['end_time']
                    dur = end_time - start_time
                    text = segment_file['text']
                    # Joined here so a malformed subset list skips the
                    # segment before anything has been written for it.
                    segment_sub_names = " ".join(segment_file["subsets"])
                except Exception:
                    print(f'Warning: {segment_file} something is wrong, '
                          f'skipped')
                    continue

                utt2text.write(f'{sid}\t{text}\n')
                segments.write(f'{sid}\t{aid}\t{start_time}\t{end_time}\n')
                utt2dur.write(f'{sid}\t{dur}\n')
                utt2subsets.write(f'{sid}\t{segment_sub_names}\n')


def main():
    """Script entry point: parse the arguments and run the extraction."""
    args = get_args()

    meta_analysis(args.input_json, args.output_dir)


if __name__ == '__main__':
    main()
def read_file(wav_scp, segments):
    """Load the recording table and the segment list.

    Args:
        wav_scp: Path of a file with one ``<wav_id> <path>`` pair per line.
            Only the first run of whitespace separates the two fields, so
            the path itself may contain spaces. Blank lines are ignored.
        segments: Path of a file with four whitespace-separated fields per
            line: ``<utt_id> <wav_id> <start> <end>``, where start and end
            are parsed as floats. Blank lines are ignored.

    Returns:
        Four parallel lists: utterance ids, source audio paths (looked up by
        wav id), start times and end times.

    Raises:
        ValueError: If a non-blank line has the wrong number of fields or a
            time that is not a number.
        KeyError: If a segment refers to a wav id absent from ``wav_scp``.
    """
    wav_scp_dict = {}
    with open(wav_scp, 'r', encoding='UTF-8') as fin:
        for line_str in fin:
            fields = line_str.strip().split(maxsplit=1)
            if not fields:
                continue
            if len(fields) != 2:
                raise ValueError(
                    'Invalid wav.scp line: {!r}'.format(line_str))
            wav_id, path = fields
            wav_scp_dict[wav_id] = path

    utt_list = []
    seg_path_list = []
    start_time_list = []
    end_time_list = []
    with open(segments, 'r', encoding='UTF-8') as fin:
        for line_str in fin:
            arr = line_str.strip().split()
            if not arr:
                continue
            # Explicit check instead of `assert`, which is removed under -O.
            if len(arr) != 4:
                raise ValueError(
                    'Invalid segments line: {!r}'.format(line_str))
            utt_list.append(arr[0])
            seg_path_list.append(wav_scp_dict[arr[1]])
            start_time_list.append(float(arr[2]))
            end_time_list.append(float(arr[3]))
    return utt_list, seg_path_list, start_time_list, end_time_list


# TODO(Qijie): Fix the process logic
def output(output_wav_scp, utt_list, seg_path_list, start_time_list,
           end_time_list):
    """Cut every segment out of its source audio and list the results.

    Each segment is sliced from its source file, converted to a 16000 Hz
    frame rate and exported as ``<utt_id>.wav``. One ``<utt_id> <wav path>``
    line per segment is written to ``output_wav_scp``. A source file is only
    reloaded when it differs from the one used by the previous segment.

    Args:
        output_wav_scp: Path of the wav.scp file to write.
        utt_list: Utterance ids.
        seg_path_list: Source audio path of each utterance.
        start_time_list: Segment start times in seconds.
        end_time_list: Segment end times in seconds.
    """
    num_utts = len(utt_list)
    # Report progress about every 1% of the utterances. The step must never
    # be 0: with fewer than 100 utterances the previous
    # int(num_utts * 0.01) was 0 and `i % step` raised ZeroDivisionError.
    step = max(1, num_utts // 100)
    records = zip(utt_list, seg_path_list, start_time_list, end_time_list)
    with open(output_wav_scp, 'w', encoding='UTF-8') as fout:
        previous_wav_path = ""
        source_wav = None
        for i, (utt_id, current_wav_path, start_sec, end_sec) \
                in enumerate(records):
            # NOTE(review): replace() rewrites every "audio" substring of the
            # directory, not only the final path component - confirm this is
            # acceptable for the dataset locations in use.
            output_dir = os.path.dirname(current_wav_path) \
                .replace("audio", 'audio_seg')
            seg_wav_path = os.path.join(output_dir, utt_id + '.wav')

            if current_wav_path != previous_wav_path:
                # The export below cannot create missing directories.
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                source_wav = AudioSegment.from_file(current_wav_path)
                previous_wav_path = current_wav_path

            # pydub slices audio by milliseconds.
            start = int(start_sec * 1000)
            end = int(end_sec * 1000)
            target_audio = source_wav[start:end].set_frame_rate(16000)
            target_audio.export(seg_wav_path, format="wav")

            fout.write("{} {}\n".format(utt_id, seg_wav_path))
            if i % step == 0:
                print("seg wav finished: {}%".format(i * 100 // num_utts))


def main():
    """Script entry point.

    usage: python3 process_opus.py wav.scp segments output_wav.scp
    """
    if len(sys.argv) < 4:
        sys.exit(
            'usage: python3 process_opus.py wav.scp segments output_wav.scp')
    wav_scp = sys.argv[1]
    segments = sys.argv[2]
    output_wav_scp = sys.argv[3]

    utt_list, seg_path_list, start_time_list, end_time_list \
        = read_file(wav_scp, segments)
    output(output_wav_scp, utt_list, seg_path_list, start_time_list,
           end_time_list)


if __name__ == '__main__':
    main()
#!/usr/bin/env bash

# Copyright 2021  Xiaomi Corporation (Author: Yongqing Wang)
#                 Seasalt AI, Inc (Author: Guoguo Chen)
#                 Mobvoi Inc(Author: Di Wu, Binbin Zhang)
#                 NPU, ASLP Group (Author: Qijie Shao)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Extracts the WenetSpeech metadata into <data-dir>/<prefix>corpus and splits
# it into one data directory per subset (train, dev, test_net, test_meeting).

set -e
set -o pipefail

stage=1
prefix=
train_subset=L

# Same option parser as the other scripts of this recipe; MAIN_ROOT is
# exported by path.sh.
. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

# filter_by_id <id-list> <input> <output> [<field>]
# Copies to <output> the lines of <input> whose <field>-th column (default: 1)
# appears in the first column of <id-list>.
filter_by_id () {
  idlist=$1
  input=$2
  output=$3
  field=1
  if [ $# -eq 4 ]; then
    field=$4
  fi
  # The first loop must read from the id-list handle F: an empty `while()`
  # never reads a line and makes the split below fail immediately.
  cat $input | perl -se '
    open(F, "<$idlist") || die "Could not open id-list file $idlist";
    while(<F>) {
      @A = split;
      @A>=1 || die "Invalid id-list file line $_";
      $seen{$A[0]} = 1;
    }
    while(<>) {
      @A = split;
      @A > 0 || die "Invalid file line $_";
      @A >= $field || die "Invalid file line $_";
      if ($seen{$A[$field-1]}) {
        print $_;
      }
    }' -- -idlist="$idlist" -field="$field" > $output ||\
  (echo "$0: filter_by_id() error: $input" && exit 1) || exit 1;
}

# subset_data_dir <utt-list> <src-dir> <dest-dir>
# Creates <dest-dir> with utt2dur, text and segments restricted to the
# utterances of <utt-list>, and wav.scp restricted to the recordings that the
# kept segments refer to.
subset_data_dir () {
  utt_list=$1
  src_dir=$2
  dest_dir=$3
  mkdir -p $dest_dir || exit 1;
  # wav.scp text segments utt2dur
  filter_by_id $utt_list $src_dir/utt2dur $dest_dir/utt2dur ||\
    (echo "$0: subset_data_dir() error: $src_dir/utt2dur" && exit 1) || exit 1;
  filter_by_id $utt_list $src_dir/text $dest_dir/text ||\
    (echo "$0: subset_data_dir() error: $src_dir/text" && exit 1) || exit 1;
  filter_by_id $utt_list $src_dir/segments $dest_dir/segments ||\
    (echo "$0: subset_data_dir() error: $src_dir/segments" && exit 1) || exit 1;
  # Column 2 of segments is the recording id; it selects the wav.scp entries.
  awk '{print $2}' $dest_dir/segments | sort | uniq > $dest_dir/reco
  filter_by_id $dest_dir/reco $src_dir/wav.scp $dest_dir/wav.scp ||\
    (echo "$0: subset_data_dir() error: $src_dir/wav.scp" && exit 1) || exit 1;
  rm -f $dest_dir/reco
}

if [ $# -ne 2 ]; then
  echo "Usage: $0 [options] <wenetspeech-dataset-dir> <data-dir>"
  echo " e.g.: $0 --train-subset L /disk1/audio_data/wenetspeech/ data/"
  echo ""
  echo "This script takes the WenetSpeech source directory, and prepares the"
  echo "WeNet format data directory."
  echo "  --prefix <prefix>                # Prefix for output data directory."
  echo "  --stage <stage>                  # Processing stage."
  echo "  --train-subset <L|M|S|W>         # Train subset to be created."
  exit 1
fi

wenetspeech_dir=$1
data_dir=$2

# Maps a subset label of utt2subsets to the name of its data directory.
declare -A subsets
subsets=(
  [L]="train_l"
  [M]="train_m"
  [S]="train_s"
  [W]="train_w"
  [DEV]="dev"
  [TEST_NET]="test_net"
  [TEST_MEETING]="test_meeting")

# A non-empty prefix gets a trailing underscore; an empty one stays empty.
prefix=${prefix:+${prefix}_}

corpus_dir=$data_dir/${prefix}corpus/
if [ $stage -le 1 ]; then
  echo "$0: Extract meta into $corpus_dir"
  # Sanity check.
  [ ! -f $wenetspeech_dir/WenetSpeech.json ] &&\
    echo "$0: Please download $wenetspeech_dir/WenetSpeech.json!" && exit 1;
  [ ! -d $wenetspeech_dir/audio ] &&\
    echo "$0: Please download $wenetspeech_dir/audio!" && exit 1;

  [ ! -d $corpus_dir ] && mkdir -p $corpus_dir

  # Files to be created:
  #   wav.scp text segments utt2dur
  python3 local/extract_meta.py \
    $wenetspeech_dir/WenetSpeech.json $corpus_dir || exit 1;
fi

if [ $stage -le 2 ]; then
  echo "$0: Split data to train, dev, test_net, and test_meeting"
  [ ! -f $corpus_dir/utt2subsets ] &&\
    echo "$0: No such file $corpus_dir/utt2subsets!" && exit 1;
  for label in $train_subset DEV TEST_NET TEST_MEETING; do
    if [ ! ${subsets[$label]+set} ]; then
      echo "$0: Subset $label is not defined in WenetSpeech.json." && exit 1;
    fi
    subset=${subsets[$label]}
    [ ! -d $data_dir/${prefix}$subset ] && mkdir -p $data_dir/${prefix}$subset
    # Keep the utterances whose subset columns (2..NF) contain this label.
    cat $corpus_dir/utt2subsets | \
       awk -v s=$label '{for (i=2;i<=NF;i++) if($i==s) print $0;}' \
       > $corpus_dir/${prefix}${subset}_utt_list|| exit 1;
    subset_data_dir $corpus_dir/${prefix}${subset}_utt_list \
      $corpus_dir $data_dir/${prefix}$subset || exit 1;
  done
fi

echo "$0: Done"
#!/bin/bash

# Top-level driver of the WenetSpeech recipe. Each numbered stage runs only
# when it lies between --stage and --stop_stage (both inclusive).

. path.sh || exit 1;
set -e

gpus=0,1,2,3,4,5,6,7
stage=0
stop_stage=100
conf_path=conf/conformer.yaml

average_checkpoint=true
avg_num=10

. ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

avg_ckpt=avg_${avg_num}
# The checkpoint name is the config file name up to its first dot.
conf_name=$(basename ${conf_path})
ckpt=${conf_name%%.*}
echo "checkpoint name ${ckpt}"

audio_file="data/tmp.wav"

# Succeeds when stage number $1 is within [stage, stop_stage].
stage_enabled() {
    [ ${stage} -le $1 ] && [ ${stop_stage} -ge $1 ]
}

# Stage 0: prepare data
if stage_enabled 0; then
    bash ./local/data.sh || exit -1
fi

# Stage 1: train model, all `ckpt` under `exp` dir
if stage_enabled 1; then
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${ckpt}
fi

# Stage 2: avg n best model
if stage_enabled 2; then
    avg.sh best exp/${ckpt}/checkpoints ${avg_num}
fi

# Stage 3: test ckpt avg_n
if stage_enabled 3; then
    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

# Stage 4: ctc alignment of test data
if stage_enabled 4; then
    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi

# Stage 5: export ckpt avg_n
if stage_enabled 5; then
    CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

# Stage 7: test a single .wav file
if stage_enabled 7; then
    CUDA_VISIBLE_DEVICES=0 ./local/test_hub.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} ${audio_file} || exit -1
fi