From 4ccb84831cd2a61fb2d8b5225c7abb90e7cb7b70 Mon Sep 17 00:00:00 2001
From: Junkun
Date: Wed, 4 Aug 2021 13:45:44 -0700
Subject: [PATCH] script for TED-En-Zh translation

---
 examples/dataset/ted_en_zh/.gitignore      |   6 +
 examples/dataset/ted_en_zh/ted_en_zh.py    | 114 ++++++++++++++++++
 examples/ted_en_zh/conf/transformer.yaml   | 109 +++++++++++++++++
 .../conf/transformer_joint_noam.yaml       | 111 +++++++++++++++++
 examples/ted_en_zh/local/data.sh           | 111 +++++++++++++++++
 examples/ted_en_zh/local/test.sh           |  35 ++++++
 examples/ted_en_zh/local/train.sh          |  33 +++++
 examples/ted_en_zh/path.sh                 |  14 +++
 examples/ted_en_zh/run.sh                  |  40 ++++++
 utils/build_vocab.py                       |  13 +-
 utils/format_triplet_data.py               |  96 +++++++++++++++
 11 files changed, 679 insertions(+), 3 deletions(-)
 create mode 100644 examples/dataset/ted_en_zh/.gitignore
 create mode 100644 examples/dataset/ted_en_zh/ted_en_zh.py
 create mode 100644 examples/ted_en_zh/conf/transformer.yaml
 create mode 100644 examples/ted_en_zh/conf/transformer_joint_noam.yaml
 create mode 100755 examples/ted_en_zh/local/data.sh
 create mode 100755 examples/ted_en_zh/local/test.sh
 create mode 100755 examples/ted_en_zh/local/train.sh
 create mode 100644 examples/ted_en_zh/path.sh
 create mode 100755 examples/ted_en_zh/run.sh
 create mode 100755 utils/format_triplet_data.py

diff --git a/examples/dataset/ted_en_zh/.gitignore b/examples/dataset/ted_en_zh/.gitignore
new file mode 100644
index 000000000..ad6ab64af
--- /dev/null
+++ b/examples/dataset/ted_en_zh/.gitignore
@@ -0,0 +1,6 @@
+*.tar.gz.*
+manifest.*
+*.md
+EN-ZH/
+train-split/
+test-segment/
\ No newline at end of file
diff --git a/examples/dataset/ted_en_zh/ted_en_zh.py b/examples/dataset/ted_en_zh/ted_en_zh.py
new file mode 100644
index 000000000..08f15119e
--- /dev/null
+++ b/examples/dataset/ted_en_zh/ted_en_zh.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Prepare the TED-En-Zh speech translation dataset.
+
+Create manifest files from the split dataset.
+dev set: tst2010, test set: tst2015
+A manifest file is in JSON-line format: each line contains the
+metadata (i.e. audio filepath, transcript and audio duration)
+of one audio file in the data set.
+"""
+import argparse
+import codecs
+import json
+import os
+
+import soundfile
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+    "--src_dir",
+    default="",
+    type=str,
+    help="Directory of the Kaldi-style split data. (default: %(default)s)")
+parser.add_argument(
+    "--manifest_prefix",
+    default="manifest",
+    type=str,
+    help="Filepath prefix for output manifests. (default: %(default)s)")
+args = parser.parse_args()
+
+
+def create_manifest(data_dir, manifest_path_prefix):
+    print("Creating manifest %s ..." % manifest_path_prefix)
+    json_lines = []
+
+    data_types_infos = [('train', 'train-split/train-segment', 'En-Zh/train.en-zh'),
+                        ('dev', 'test-segment/tst2010', 'En-Zh/tst2010.en-zh'),
+                        ('test', 'test-segment/tst2015', 'En-Zh/tst2015.en-zh')]
+    for data_info in data_types_infos:
+        dtype, audio_relative_dir, text_relative_path = data_info
+        del json_lines[:]
+        total_sec = 0.0
+        total_text = 0.0
+        total_num = 0
+
+        text_path = os.path.join(data_dir, text_relative_path)
+        audio_dir = os.path.join(data_dir, audio_relative_dir)
+
+        for line in codecs.open(text_path, 'r', 'utf-8', errors='ignore'):
+            line = line.strip()
+            if len(line) < 1:
+                continue
+            audio_id, transcription, translation = line.split('\t')
+            utt = audio_id.split('.')[0]
+
+            audio_path = os.path.join(audio_dir, audio_id)
+            if os.path.exists(audio_path):
+                if os.path.getsize(audio_path) < 30000:
+                    continue
+                audio_data, samplerate = soundfile.read(audio_path)
+                duration = float(len(audio_data) / samplerate)
+                json_lines.append(
+                    json.dumps(
+                        {
+                            'utt': utt,
+                            'feat': audio_path,
+                            'feat_shape': (duration, ),  # second
+                            'text': " ".join(translation.split()),
+                            'text1': " ".join(transcription.split())
+                        },
+                        ensure_ascii=False))
+
+                total_sec += duration
+                total_text += len(translation.split())
+                total_num += 1
+                if not total_num % 1000:
+                    print(dtype, 'Processed:', total_num)
+
+        manifest_path = manifest_path_prefix + '.' + dtype + '.raw'
+        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
+            for line in json_lines:
+                fout.write(line + '\n')
+
+
+def prepare_dataset(src_dir, manifest_path=None):
+    """create manifest file."""
+    if os.path.isdir(manifest_path):
+        manifest_path = os.path.join(manifest_path, 'manifest')
+    if manifest_path:
+        create_manifest(src_dir, manifest_path)
+
+
+def main():
+    if args.src_dir.startswith('~'):
+        args.src_dir = os.path.expanduser(args.src_dir)
+
+    prepare_dataset(src_dir=args.src_dir, manifest_path=args.manifest_prefix)
+
+    print("Manifest preparation done!")
+
+
+if __name__ == '__main__':
+    main()
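For reference, each line that ted_en_zh.py writes to manifest.<set>.raw is a standalone JSON object describing one utterance, with the Chinese translation under 'text' and the English transcript under 'text1'. A minimal sketch of one such line; the utterance id, path, duration and texts below are made-up examples, not real data:

# hypothetical raw-manifest line, mirroring the dict built in create_manifest()
import json

line = json.dumps(
    {
        'utt': 'ted_0001_0000023',  # audio_id with the extension stripped
        'feat': '/path/to/test-segment/tst2010/ted_0001_0000023.wav',
        'feat_shape': (6.42, ),     # duration in seconds
        'text': '这 是 一个 例子',   # Chinese translation (target)
        'text1': 'this is an example'  # English transcript (source)
    },
    ensure_ascii=False)
print(line)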
diff --git a/examples/ted_en_zh/conf/transformer.yaml b/examples/ted_en_zh/conf/transformer.yaml
new file mode 100644
index 000000000..10a3e7f57
--- /dev/null
+++ b/examples/ted_en_zh/conf/transformer.yaml
@@ -0,0 +1,109 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.train.tiny
+  dev_manifest: data/manifest.dev
+  test_manifest: data/manifest.test
+  min_input_len: 0.5  # second
+  max_input_len: 3000.0  # second
+  min_output_len: 0.0  # tokens
+  max_output_len: 400.0  # tokens
+  min_output_input_ratio: 0.01
+  max_output_input_ratio: 20.0
+
+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'spm'
+  spm_model_prefix: data/bpe_unigram_8000
+  mean_std_filepath: ""
+  # augmentation_config: conf/augmentation.json
+  batch_size: 10
+  raw_wav: True  # use raw_wav or kaldi feature
+  specgram_type: fbank  # one of: linear, mfcc, fbank
+  feat_dim: 80
+  delta_delta: False
+  dither: 1.0
+  target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 25.0
+  use_dB_normalization: True
+  target_dB: -20
+  random_seed: 0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 2
+
+
+# network architecture
+model:
+  cmvn_file: "data/mean_std.json"
+  cmvn_file_type: "json"
+  # encoder related
+  encoder: transformer
+  encoder_conf:
+    output_size: 256  # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12  # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+
+  # decoder related
+  decoder: transformer
+  decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+  # hybrid CTC/attention
+  model_conf:
+    asr_weight: 0.0
+    ctc_weight: 0.0
+    lsm_weight: 0.1  # label smoothing option
+    length_normalized_loss: false
+
+
+training:
+  n_epoch: 120
+  accum_grad: 2
+  global_grad_clip: 5.0
+  optim: adam
+  optim_conf:
+    lr: 0.004
+    weight_decay: 1e-06
+  scheduler: warmuplr  # pytorch v1.1.0+ required
+  scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+  log_interval: 5
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5
+
+
+decoding:
+  batch_size: 5
+  error_rate_type: char-bleu
+  decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
+  alpha: 2.5
+  beta: 0.3
+  beam_size: 10
+  cutoff_prob: 1.0
+  cutoff_top_n: 0
+  num_proc_bsearch: 8
+  ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
+  decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
+                           # <0: for decoding, use full chunk.
+                           # >0: for decoding, use fixed chunk size as set.
+                           # 0: used for training, it's prohibited here.
+  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
+  simulate_streaming: False  # simulate streaming inference. Defaults to False.
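The conv2d input layer in this config subsamples the fbank frames by roughly 4x in time, so with stride_ms: 10.0 each encoder step covers about 40 ms of audio. A rough sketch of the resulting sequence lengths; the two-layer kernel-3, stride-2 arithmetic is the usual conv2d subsampling in transformer speech encoders and is an assumption here, not a quote from this codebase:

def num_fbank_frames(duration_s: float, stride_ms: float = 10.0) -> int:
    """Approximate number of 10 ms fbank frames for one utterance."""
    return int(duration_s * 1000 / stride_ms)

def encoder_length(n_frames: int) -> int:
    """Two conv layers, kernel 3 / stride 2 -> ~4x time reduction (assumed)."""
    for _ in range(2):
        n_frames = (n_frames - 1) // 2
    return n_frames

frames = num_fbank_frames(6.42)        # ~642 frames of 80-dim fbank
print(frames, encoder_length(frames))  # 642 -> 159 encoder steps, ~40 ms each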
diff --git a/examples/ted_en_zh/conf/transformer_joint_noam.yaml b/examples/ted_en_zh/conf/transformer_joint_noam.yaml
new file mode 100644
index 000000000..ba384f8cd
--- /dev/null
+++ b/examples/ted_en_zh/conf/transformer_joint_noam.yaml
@@ -0,0 +1,111 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.train
+  dev_manifest: data/manifest.dev
+  test_manifest: data/manifest.test
+  min_input_len: 0.5  # second
+  max_input_len: 3000.0  # second
+  min_output_len: 0.0  # tokens
+  max_output_len: 400.0  # tokens
+  min_output_input_ratio: 0.01
+  max_output_input_ratio: 20.0
+
+collator:
+  vocab_filepath: data/vocab.txt
+  unit_type: 'spm'
+  spm_model_prefix: data/bpe_unigram_8000
+  mean_std_filepath: ""
+  # augmentation_config: conf/augmentation.json
+  batch_size: 10
+  raw_wav: True  # use raw_wav or kaldi feature
+  specgram_type: fbank  # one of: linear, mfcc, fbank
+  feat_dim: 80
+  delta_delta: False
+  dither: 1.0
+  target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 25.0
+  use_dB_normalization: True
+  target_dB: -20
+  random_seed: 0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 2
+
+
+# network architecture
+model:
+  cmvn_file: "data/mean_std.json"
+  cmvn_file_type: "json"
+  # encoder related
+  encoder: transformer
+  encoder_conf:
+    output_size: 256  # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12  # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d  # encoder input type, you can choose conv2d, conv2d6 and conv2d8
+    normalize_before: true
+
+  # decoder related
+  decoder: transformer
+  decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+  # hybrid CTC/attention
+  model_conf:
+    asr_weight: 0.5
+    ctc_weight: 0.3
+    lsm_weight: 0.1  # label smoothing option
+    length_normalized_loss: false
+
+
+training:
+  n_epoch: 120
+  accum_grad: 2
+  global_grad_clip: 5.0
+  optim: adam
+  optim_conf:
+    lr: 2.5
+    weight_decay: 1e-06
+  scheduler: noam
+  scheduler_conf:
+    warmup_steps: 25000
+    lr_decay: 1.0
+  log_interval: 5
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5
+
+
+decoding:
+  batch_size: 5
+  error_rate_type: char-bleu
+  decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
+  alpha: 2.5
+  beta: 0.3
+  beam_size: 10
+  cutoff_prob: 1.0
+  cutoff_top_n: 0
+  num_proc_bsearch: 8
+  ctc_weight: 0.5  # ctc weight for attention rescoring decode mode.
+  decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
+                           # <0: for decoding, use full chunk.
+                           # >0: for decoding, use fixed chunk size as set.
+                           # 0: used for training, it's prohibited here.
+  num_decoding_left_chunks: -1  # number of left chunks for decoding. Defaults to -1.
+  simulate_streaming: False  # simulate streaming inference. Defaults to False.
+
+
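With scheduler: noam, the configured lr: 2.5 is a scale factor rather than the actual learning rate. Assuming the standard Noam definition from "Attention Is All You Need" (the exact implementation in this codebase may differ slightly), the effective rate warms up linearly over warmup_steps and then decays as the inverse square root of the step:

def noam_lr(step: int, scale: float = 2.5, d_model: int = 256,
            warmup_steps: int = 25000) -> float:
    """Assumed Noam schedule: linear warmup, then ~step**-0.5 decay."""
    return scale * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

# Peak is reached at step == warmup_steps:
# 2.5 * 256**-0.5 * 25000**-0.5
print(noam_lr(25000))  # ≈ 9.9e-4, so the real peak LR is roughly 1e-3

Under this reading, lr: 2.5 together with output_size: 256 and warmup_steps: 25000 yields a peak learning rate near 1e-3, which is why the value looks so much larger than the 0.004 used with warmuplr above.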
diff --git a/examples/ted_en_zh/local/data.sh b/examples/ted_en_zh/local/data.sh
new file mode 100755
index 000000000..0a5c58aa5
--- /dev/null
+++ b/examples/ted_en_zh/local/data.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+
+stage=-1
+stop_stage=100
+
+# bpemode (unigram or bpe)
+nbpe=8000
+bpemode=unigram
+bpeprefix="data/bpe_${bpemode}_${nbpe}"
+DATA_DIR=
+
+
+source ${MAIN_ROOT}/utils/parse_options.sh
+
+
+mkdir -p data
+TARGET_DIR=${MAIN_ROOT}/examples/dataset
+mkdir -p ${TARGET_DIR}
+
+if [ ! -d ${DATA_DIR} ]; then
+    echo "Error: Dataset is not available. Please download and unzip the dataset"
+    echo "Download Link: https://pan.baidu.com/s/18L-59wgeS96WkObISrytQQ Passwd: bva0"
+    echo "The tree of the directory should be:"
+    echo "."
+    echo "|-- En-Zh"
+    echo "|-- test-segment"
+    echo "    |-- tst2010"
+    echo "    |-- ..."
+    echo "|-- train-split"
+    echo "    |-- train-segment"
+    echo "|-- README.md"
+
+    exit 1
+fi
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    # generate manifests
+    python3 ${TARGET_DIR}/ted_en_zh/ted_en_zh.py \
+    --manifest_prefix="data/manifest" \
+    --src_dir="${DATA_DIR}"
+
+    echo "Raw data pre-processing complete."
+fi
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # build vocabulary
+    python3 ${MAIN_ROOT}/utils/build_vocab.py \
+    --unit_type "spm" \
+    --spm_vocab_size=${nbpe} \
+    --spm_mode ${bpemode} \
+    --spm_model_prefix ${bpeprefix} \
+    --vocab_path="data/vocab.txt" \
+    --text_keys 'text' 'text1' \
+    --manifest_paths="data/manifest.train.raw"
+
+
+    if [ $? -ne 0 ]; then
+        echo "Build vocabulary failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # compute mean and stddev for normalizer
+    num_workers=$(nproc)
+    python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
+    --manifest_path="data/manifest.train.raw" \
+    --num_samples=-1 \
+    --specgram_type="fbank" \
+    --feat_dim=80 \
+    --delta_delta=false \
+    --sample_rate=16000 \
+    --stride_ms=10.0 \
+    --window_ms=25.0 \
+    --use_dB_normalization=False \
+    --num_workers=${num_workers} \
+    --output_path="data/mean_std.json"
+
+    if [ $? -ne 0 ]; then
+        echo "Compute mean and stddev failed. Terminated."
+        exit 1
+    fi
+fi
+
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # format manifest with tokenids, vocab size
+    for set in train dev test; do
+    {
+        python3 ${MAIN_ROOT}/utils/format_triplet_data.py \
+        --feat_type "raw" \
+        --cmvn_path "data/mean_std.json" \
+        --unit_type "spm" \
+        --spm_model_prefix ${bpeprefix} \
+        --vocab_path="data/vocab.txt" \
+        --manifest_path="data/manifest.${set}.raw" \
+        --output_path="data/manifest.${set}"
+
+        if [ $? -ne 0 ]; then
+            echo "Format manifest failed. Terminated."
+            exit 1
+        fi
+    }&
+    done
+    wait
+fi
+
+echo "TED-En-Zh data preparation done."
+exit 0
diff --git a/examples/ted_en_zh/local/test.sh b/examples/ted_en_zh/local/test.sh
new file mode 100755
index 000000000..802bb13c2
--- /dev/null
+++ b/examples/ted_en_zh/local/test.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+if [ $# != 2 ];then
+    echo "usage: ${0} config_path ckpt_path_prefix"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+config_path=$1
+ckpt_prefix=$2
+
+for type in fullsentence; do
+    echo "decoding ${type}"
+    batch_size=32
+    python3 -u ${BIN_DIR}/test.py \
+    --device ${device} \
+    --nproc 1 \
+    --config ${config_path} \
+    --result_file ${ckpt_prefix}.${type}.rsl \
+    --checkpoint_path ${ckpt_prefix} \
+    --opts decoding.decoding_method ${type} decoding.batch_size ${batch_size}
+
+    if [ $? -ne 0 ]; then
+        echo "Failed in evaluation!"
+        exit 1
+    fi
+done
+
+exit 0
diff --git a/examples/ted_en_zh/local/train.sh b/examples/ted_en_zh/local/train.sh
new file mode 100755
index 000000000..f3eb98daf
--- /dev/null
+++ b/examples/ted_en_zh/local/train.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+if [ $# != 2 ];then
+    echo "usage: CUDA_VISIBLE_DEVICES=0 ${0} config_path ckpt_name"
+    exit -1
+fi
+
+ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+echo "using $ngpu gpus..."
+
+config_path=$1
+ckpt_name=$2
+
+device=gpu
+if [ ${ngpu} == 0 ];then
+    device=cpu
+fi
+echo "using ${device}..."
+
+mkdir -p exp
+
+python3 -u ${BIN_DIR}/train.py \
+--device ${device} \
+--nproc ${ngpu} \
+--config ${config_path} \
+--output exp/${ckpt_name}
+
+if [ $? -ne 0 ]; then
+    echo "Failed in training!"
+    exit 1
+fi
+
+exit 0
diff --git a/examples/ted_en_zh/path.sh b/examples/ted_en_zh/path.sh
new file mode 100644
index 000000000..881a5b918
--- /dev/null
+++ b/examples/ted_en_zh/path.sh
@@ -0,0 +1,14 @@
+export MAIN_ROOT=${PWD}/../../
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
+
+
+MODEL=u2_st
+export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
diff --git a/examples/ted_en_zh/run.sh b/examples/ted_en_zh/run.sh
new file mode 100755
index 000000000..89048f3dd
--- /dev/null
+++ b/examples/ted_en_zh/run.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set -e
+source path.sh
+
+stage=0
+stop_stage=100
+conf_path=conf/transformer_joint_noam.yaml
+avg_num=5
+data_path=./TED-En-Zh  # path to unzipped data
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
+
+avg_ckpt=avg_${avg_num}
+ckpt=$(basename ${conf_path} | awk -F'.' '{print $1}')
+echo "checkpoint name ${ckpt}"
+
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    bash ./local/data.sh --DATA_DIR ${data_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `exp` dir
+    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./local/train.sh ${conf_path} ${ckpt}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # avg n best model
+    ../../utils/avg.sh exp/${ckpt}/checkpoints ${avg_num}
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # test ckpt avg_n
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # export ckpt avg_n
+    CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
+fi
diff --git a/utils/build_vocab.py b/utils/build_vocab.py
index 76092b25d..151d52f82 100755
--- a/utils/build_vocab.py
+++ b/utils/build_vocab.py
@@ -44,6 +44,11 @@ add_arg('manifest_paths', str,
         "You can provide multiple manifest files.",
         nargs='+',
         required=True)
+add_arg('text_keys', str,
+        'text',
+        "keys of the text in manifest for building vocabulary. "
+        "You can provide multiple keys.",
+        nargs='+')
 # bpe
 add_arg('spm_vocab_size', int, 0, "Vocab size for spm.")
 add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
@@ -58,10 +63,10 @@ def count_manifest(counter, text_feature, manifest_path):
         line = text_feature.tokenize(line_json['text'])
         counter.update(line)
 
-def dump_text_manifest(fileobj, manifest_path):
+def dump_text_manifest(fileobj, manifest_path, key='text'):
     manifest_jsons = read_manifest(manifest_path)
     for line_json in manifest_jsons:
-        fileobj.write(line_json['text'] + "\n")
+        fileobj.write(line_json[key] + "\n")
 
 def main():
     print_arguments(args, globals())
@@ -78,7 +83,9 @@ def main():
         fp = tempfile.NamedTemporaryFile(mode='w', delete=False)
 
         for manifest_path in args.manifest_paths:
-            dump_text_manifest(fp, manifest_path)
+            text_keys = args.text_keys if isinstance(args.text_keys, list) else [args.text_keys]
+            for text_key in text_keys:
+                dump_text_manifest(fp, manifest_path, key=text_key)
         fp.close()
         # train
         spm.SentencePieceTrainer.Train(
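The isinstance() guard in main() is needed because argparse only returns a list for text_keys when the flag is actually passed; if the caller omits --text_keys, the value is the plain default string 'text'. A small standalone demo of that argparse behavior (a hypothetical script, not part of this patch):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--text_keys', type=str, default='text', nargs='+')

print(parser.parse_args([]).text_keys)  # 'text' -- a str, the default untouched
print(parser.parse_args(['--text_keys', 'text', 'text1']).text_keys)
# ['text', 'text1'] -- a list, hence the isinstance() check before iterating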
+"""format manifest with more metadata.""" +import argparse +import functools +import json + +from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer +from deepspeech.frontend.utility import load_cmvn +from deepspeech.frontend.utility import read_manifest +from deepspeech.utils.utility import add_arguments +from deepspeech.utils.utility import print_arguments + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi") +add_arg('cmvn_path', str, + 'examples/librispeech/data/mean_std.json', + "Filepath of cmvn.") +add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") +add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath of the vocabulary.") +add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) +# bpe +add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") +add_arg('output_path', str, None, "filepath of formated manifest.", required=True) +# yapf: disable +args = parser.parse_args() + + +def main(): + print_arguments(args, globals()) + fout = open(args.output_path, 'w', encoding='utf-8') + + # get feat dim + mean, std = load_cmvn(args.cmvn_path, filetype='json') + feat_dim = mean.shape[0] #(D) + print(f"Feature dim: {feat_dim}") + + text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) + vocab_size = text_feature.vocab_size + print(f"Vocab size: {vocab_size}") + + count = 0 + for manifest_path in args.manifest_paths: + manifest_jsons = read_manifest(manifest_path) + for line_json in manifest_jsons: + # text: translation text, text1: transcript text. + # Currently only support joint-vocab, will add separate vocabs setting. + line = line_json['text'] + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + line_json['token'] = tokens + line_json['token_id'] = tokenids + line_json['token_shape'] = (len(tokenids), vocab_size) + line = line_json['text1'] + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + line_json['token1'] = tokens + line_json['token_id1'] = tokenids + line_json['token_shape1'] = (len(tokenids), vocab_size) + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + if args.feat_type == 'raw': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') + fout.write(json.dumps(line_json) + '\n') + count += 1 + + print(f"Examples number: {count}") + fout.close() + + +if __name__ == '__main__': + main()