commit
f7032c8256
@ -0,0 +1,17 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Prepend ``path`` to ``sys.path`` unless it is already listed there."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
||||
|
||||
|
||||
# Directory that contains this module.
this_dir = os.path.dirname(__file__)
# Its parent directory is put on the import path.
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
|
@ -0,0 +1,29 @@
|
||||
#! /usr/bin/env bash

# Submit a DeepSpeech2 training job to PaddleCloud.
#
# pcloud_train.sh is copied into the parent directory of the current working
# directory (DS2_PATH) so that it is part of the submitted package, and is
# removed again after the submission. The script exits with the status of
# `paddlecloud submit`.

TRAIN_MANIFEST="cloud/cloud_manifests/cloud.manifest.train"
DEV_MANIFEST="cloud/cloud_manifests/cloud.manifest.dev"
CLOUD_MODEL_DIR="./checkpoints"
BATCH_SIZE=512
NUM_GPU=8
NUM_NODE=1
IS_LOCAL="True"

JOB_NAME=deepspeech-$(date +%Y%m%d%H%M%S)
DS2_PATH=${PWD%/*}

# Without the entry script in place the submitted job cannot start.
cp -f pcloud_train.sh "${DS2_PATH}" || exit 1

paddlecloud submit \
-image bootstrapper:5000/paddlepaddle/pcloud_ds2:latest \
-jobname "${JOB_NAME}" \
-cpu "${NUM_GPU}" \
-gpu "${NUM_GPU}" \
-memory 64Gi \
-parallelism "${NUM_NODE}" \
-pscpu 1 \
-pservers 1 \
-psmemory 64Gi \
-passes 1 \
-entry "sh pcloud_train.sh ${TRAIN_MANIFEST} ${DEV_MANIFEST} ${CLOUD_MODEL_DIR} ${NUM_GPU} ${BATCH_SIZE} ${IS_LOCAL}" \
"${DS2_PATH}"
# Remember the submit result: the cleanup below must not mask it.
SUBMIT_STATUS=$?

rm "${DS2_PATH}/pcloud_train.sh"

exit ${SUBMIT_STATUS}
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash

# Per-node training entry script.
#
# Usage:
#   sh pcloud_train.sh TRAIN_MANIFEST DEV_MANIFEST MODEL_PATH NUM_GPU \
#       BATCH_SIZE IS_LOCAL
#
# Splits the train/dev manifests for the current node with
# cloud/split_data.py, then runs train.py on the resulting local manifests.

TRAIN_MANIFEST=$1
DEV_MANIFEST=$2
MODEL_PATH=$3
NUM_GPU=$4
BATCH_SIZE=$5
IS_LOCAL=$6

python ./cloud/split_data.py \
--in_manifest_path="${TRAIN_MANIFEST}" \
--out_manifest_path='/local.manifest.train'
if [ $? -ne 0 ]; then
    echo "Split train manifest failed. Terminated."
    exit 1
fi

python ./cloud/split_data.py \
--in_manifest_path="${DEV_MANIFEST}" \
--out_manifest_path='/local.manifest.dev'
if [ $? -ne 0 ]; then
    echo "Split dev manifest failed. Terminated."
    exit 1
fi

# -p: do not fail when the directory is already there.
mkdir -p ./logs

# --output_model_dir used to be passed twice ('./checkpoints', then
# ${MODEL_PATH}); only the ${MODEL_PATH} occurrence is kept.
python -u train.py \
--batch_size="${BATCH_SIZE}" \
--trainer_count="${NUM_GPU}" \
--num_passes=200 \
--num_proc_data="${NUM_GPU}" \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local="${IS_LOCAL}" \
--share_rnn_weights=True \
--train_manifest='/local.manifest.train' \
--dev_manifest='/local.manifest.dev' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--output_model_dir="${MODEL_PATH}" \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped' \
2>&1 | tee ./logs/train.log
|
@ -0,0 +1,22 @@
|
||||
#! /usr/bin/env bash

# Pack the LibriSpeech data listed in the local manifests, upload it to
# PaddleCloud with upload_data.py and write the cloud manifests into
# ./cloud_manifests.

# -p: rerunning the script must not fail because the directory exists.
mkdir -p cloud_manifests

IN_MANIFESTS="../data/librispeech/manifest.train ../data/librispeech/manifest.dev-clean ../data/librispeech/manifest.test-clean"
OUT_MANIFESTS="cloud_manifests/cloud.manifest.train cloud_manifests/cloud.manifest.dev cloud_manifests/cloud.manifest.test"
CLOUD_DATA_DIR="/pfs/dlnel/home/USERNAME/deepspeech2/data/librispeech"
NUM_SHARDS=50

# IN_MANIFESTS / OUT_MANIFESTS are intentionally left unquoted: each
# space-separated path has to become its own argument.
python upload_data.py \
--in_manifest_paths ${IN_MANIFESTS} \
--out_manifest_paths ${OUT_MANIFESTS} \
--cloud_data_dir "${CLOUD_DATA_DIR}" \
--num_shards "${NUM_SHARDS}"

if [ $? -ne 0 ]
then
    echo "Upload Data Failed!"
    exit 1
fi

echo "All Done."
|
@ -0,0 +1,41 @@
|
||||
"""This tool is used for splitting data into each node of
|
||||
paddlecloud. This script should be called in paddlecloud.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
# Manifest to read (the same file for every node).
parser.add_argument(
    "--in_manifest_path",
    required=True,
    type=str,
    help="Input manifest path for all nodes.")
# Manifest to write for the current node.
parser.add_argument(
    "--out_manifest_path",
    required=True,
    type=str,
    help="Output manifest file path for current node.")
args = parser.parse_args()
|
||||
|
||||
|
||||
def split_data(in_manifest_path,
               out_manifest_path,
               trainer_id_path="/trainer_id",
               trainer_count_path="/trainer_count"):
    """Write the current trainer's share of a manifest.

    The trainer index and the number of trainers are read from two text
    files whose first line holds an integer. Manifest line ``i`` (0-based)
    is kept when ``i % trainer_count == trainer_id``; kept lines are
    stripped and written one per line.

    :param in_manifest_path: Input manifest path for all nodes.
    :type in_manifest_path: basestring
    :param out_manifest_path: Output manifest path for the current node.
    :type out_manifest_path: basestring
    :param trainer_id_path: File holding the index of the current trainer.
    :type trainer_id_path: basestring
    :param trainer_count_path: File holding the total number of trainers.
    :type trainer_count_path: basestring
    :raises ValueError: If one of the two files does not start with an
                        integer.
    """
    # strip() instead of [:-1]: slicing off the last character drops a digit
    # when the file has no trailing newline.
    with open(trainer_id_path, "r") as f:
        trainer_id = int(f.readline().strip())
    with open(trainer_count_path, "r") as f:
        trainer_count = int(f.readline().strip())

    with open(in_manifest_path, 'r') as f:
        out_manifest = [
            "%s\n" % json_line.strip() for index, json_line in enumerate(f)
            if (index % trainer_count) == trainer_id
        ]
    with open(out_manifest_path, 'w') as f:
        f.writelines(out_manifest)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
split_data(args.in_manifest_path, args.out_manifest_path)
|
@ -0,0 +1,129 @@
|
||||
"""This script is for uploading data for DeepSpeech2 training on paddlecloud.
|
||||
|
||||
Steps:
|
||||
1. Read original manifests and extract local sound files.
|
||||
2. Tar all local sound files into multiple tar files and upload them.
|
||||
3. Modify original manifests with updated paths in cloud filesystem.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
import sys
|
||||
import argparse
|
||||
import shutil
|
||||
from subprocess import call
|
||||
import _init_paths
|
||||
from data_utils.utils import read_manifest
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--in_manifest_paths",
    default=[
        "../datasets/manifest.train", "../datasets/manifest.dev",
        "../datasets/manifest.test"
    ],
    type=str,
    nargs='+',
    # The trailing space keeps "upload." and "(default: ...)" apart in the
    # rendered help text.
    help="Local filepaths of input manifests to load, pack and upload. "
    "(default: %(default)s)")
parser.add_argument(
    "--out_manifest_paths",
    default=[
        "./cloud.manifest.train", "./cloud.manifest.dev",
        "./cloud.manifest.test"
    ],
    type=str,
    nargs='+',
    help="Local filepaths of modified manifests to write to. "
    "(default: %(default)s)")
parser.add_argument(
    "--cloud_data_dir",
    required=True,
    type=str,
    help="Destination directory on paddlecloud to upload data to.")
parser.add_argument(
    "--num_shards",
    default=10,
    type=int,
    help="Number of parts to split data to. (default: %(default)s)")
parser.add_argument(
    "--local_tmp_dir",
    default="./tmp/",
    type=str,
    help="Local directory for storing temporary data. (default: %(default)s)")
args = parser.parse_args()
|
||||
|
||||
|
||||
def upload_data(in_manifest_path_list, out_manifest_path_list, local_tmp_dir,
                upload_tar_dir, num_shards):
    """Extract and pack sound files listed in the manifest files into multiple
    tar files and upload them to paddlecloud. Besides, generate new manifest
    files with updated paths in paddlecloud.

    Every sound file is stored in a tar under its base name, and its
    ``audio_filepath`` is rewritten to
    ``tar:<upload_tar_dir>/<tar name>#<sound file name>``. Each output
    manifest is written locally and then copied to ``upload_tar_dir``.

    :param in_manifest_path_list: Local manifests to read.
    :type in_manifest_path_list: list
    :param out_manifest_path_list: Local paths of the rewritten manifests,
                                   paired with the inputs by position.
    :type out_manifest_path_list: list
    :param local_tmp_dir: Local directory in which tar files are built.
    :type local_tmp_dir: basestring
    :param upload_tar_dir: Destination directory on paddlecloud.
    :type upload_tar_dir: basestring
    :param num_shards: Divisor used to derive the number of files per tar.
    :type num_shards: int
    """
    # compute total audio number
    total_line = 0
    for manifest_path in in_manifest_path_list:
        with open(manifest_path, 'r') as f:
            total_line += sum(1 for _ in f)
    line_per_tar = (total_line // num_shards) + 1

    # pack and upload shard by shard
    line_count, tar_file, tar_path = 0, None, None
    for manifest_path, out_manifest_path in zip(in_manifest_path_list,
                                                out_manifest_path_list):
        manifest = read_manifest(manifest_path)
        out_manifest = []
        for json_data in manifest:
            sound_filepath = json_data['audio_filepath']
            sound_filename = os.path.basename(sound_filepath)
            if line_count % line_per_tar == 0:
                # The current tar is full: ship it before opening the next.
                if tar_file is not None:
                    tar_file.close()
                    pcloud_cp(tar_path, upload_tar_dir)
                    os.remove(tar_path)
                tar_name = 'part-%s-of-%s.tar' % (
                    str(line_count // line_per_tar).zfill(5),
                    str(num_shards).zfill(5))
                tar_path = os.path.join(local_tmp_dir, tar_name)
                tar_file = tarfile.open(tar_path, 'w')
            tar_file.add(sound_filepath, arcname=sound_filename)
            line_count += 1
            json_data['audio_filepath'] = "tar:%s#%s" % (
                os.path.join(upload_tar_dir, tar_name), sound_filename)
            out_manifest.append("%s\n" % json.dumps(json_data))
        with open(out_manifest_path, 'w') as f:
            f.writelines(out_manifest)
        pcloud_cp(out_manifest_path, upload_tar_dir)
    # No tar was ever opened when the manifests hold no entries; only flush
    # the last tar if there is one.
    if tar_file is not None:
        tar_file.close()
        pcloud_cp(tar_path, upload_tar_dir)
        os.remove(tar_path)
|
||||
|
||||
|
||||
def pcloud_mkdir(dir):
    """Create directory ``dir`` in the PaddleCloud filesystem.

    :raises IOError: If the ``paddlecloud mkdir`` command returns non-zero.
    """
    status = call(['paddlecloud', 'mkdir', dir])
    if status == 0:
        return
    raise IOError("PaddleCloud mkdir failed: %s." % dir)
|
||||
|
||||
|
||||
def pcloud_cp(src, dst):
    """Copy src from the local filesystem to dst in the PaddleCloud
    filesystem, or download src from the PaddleCloud filesystem to dst in
    the local filesystem.

    :raises IOError: If the ``paddlecloud cp`` command returns non-zero.
    """
    status = call(['paddlecloud', 'cp', src, dst])
    if status == 0:
        return
    raise IOError("PaddleCloud cp failed: from [%s] to [%s]." % (src, dst))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    if not os.path.exists(args.local_tmp_dir):
        os.makedirs(args.local_tmp_dir)
    try:
        pcloud_mkdir(args.cloud_data_dir)

        upload_data(args.in_manifest_paths, args.out_manifest_paths,
                    args.local_tmp_dir, args.cloud_data_dir, args.num_shards)
    finally:
        # Remove the temporary directory even when the upload raises, so a
        # failed run does not leave partially built tar files behind.
        shutil.rmtree(args.local_tmp_dir)
|
@ -1,63 +0,0 @@
|
||||
"""Compute mean and std for feature normalizer, and save to file."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from data_utils.normalizer import FeatureNormalizer
|
||||
from data_utils.augmentor.augmentation import AugmentationPipeline
|
||||
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
|
||||
|
||||
parser = argparse.ArgumentParser(
    description='Computing mean and stddev for feature normalizer.')
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--manifest_path",
    default='datasets/manifest.train',
    type=str,
    # Trailing space separates the sentence from "(default: ...)".
    help="Manifest path for computing normalizer's mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--num_samples",
    default=2000,
    type=int,
    help="Number of samples for computing mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    default='{}',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
parser.add_argument(
    "--output_file",
    default='mean_std.npz',
    type=str,
    # Trailing space separates the sentence from "(default: ...)".
    help="Filepath to write mean and std to (.npz). "
    "(default: %(default)s)")
args = parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Build a FeatureNormalizer from the manifest given on the command line
    and write it to ``args.output_file``.
    """
    pipeline = AugmentationPipeline(args.augmentation_config)
    featurizer = AudioFeaturizer(specgram_type=args.specgram_type)

    def augment_and_featurize(audio_segment):
        # Run the augmentation pipeline on the segment, then featurize it.
        pipeline.transform_audio(audio_segment)
        return featurizer.featurize(audio_segment)

    FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=args.manifest_path,
        featurize_func=augment_and_featurize,
        num_samples=args.num_samples).write_to_file(args.output_file)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,109 @@
|
||||
"""Prepare Aishell mandarin dataset
|
||||
|
||||
Download, unpack and create manifest files.
|
||||
Manifest file is a json-format file with each line containing the
|
||||
meta data (i.e. audio filepath, transcript and audio duration)
|
||||
of each audio file in the data set.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import soundfile
|
||||
import json
|
||||
import argparse
|
||||
from data_utils.utility import download, unpack
|
||||
|
||||
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
|
||||
|
||||
URL_ROOT = 'http://www.openslr.org/resources/33'
|
||||
DATA_URL = URL_ROOT + '/data_aishell.tgz'
|
||||
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
# Where the dataset is downloaded and unpacked.
parser.add_argument(
    "--target_dir",
    type=str,
    default=DATA_HOME + "/Aishell",
    help="Directory to save the dataset. (default: %(default)s)")
# Output manifests are named "<prefix>.<split>".
parser.add_argument(
    "--manifest_prefix",
    type=str,
    default="manifest",
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
|
||||
|
||||
|
||||
def create_manifest(data_dir, manifest_path_prefix):
    """Create one manifest file per data split ('train', 'dev', 'test').

    Transcripts are read from
    ``<data_dir>/transcript/aishell_transcript_v0.8.txt`` (one
    ``<audio id> <text>`` entry per line). For every file found under
    ``<data_dir>/wav/<split>`` whose name minus its last four characters is
    a known audio id, a json line with ``audio_filepath``, ``duration`` and
    ``text`` is written to ``<manifest_path_prefix>.<split>``.

    :param data_dir: Root directory of the unpacked dataset.
    :type data_dir: basestring
    :param manifest_path_prefix: Filepath prefix of the output manifests.
    :type manifest_path_prefix: basestring
    """
    print("Creating manifest %s ..." % manifest_path_prefix)
    transcript_path = os.path.join(data_dir, 'transcript',
                                   'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    with codecs.open(transcript_path, 'r', 'utf-8') as transcript_file:
        for line in transcript_file:
            line = line.strip()
            if line == '':
                continue
            audio_id, text = line.split(' ', 1)
            # remove whitespace
            text = ''.join(text.split())
            transcript_dict[audio_id] = text

    data_types = ['train', 'dev', 'test']
    for data_type in data_types:
        # Start every split with an empty list; sharing one list across
        # splits leaks the train entries into the dev and test manifests.
        json_lines = []
        audio_dir = os.path.join(data_dir, 'wav', data_type)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname)
                # drop the 4-character extension (e.g. ".wav")
                audio_id = fname[:-4]
                # if no transcription for audio then skipped
                if audio_id not in transcript_dict:
                    continue
                audio_data, samplerate = soundfile.read(audio_path)
                duration = float(len(audio_data)) / samplerate
                text = transcript_dict[audio_id]
                json_lines.append(
                    json.dumps(
                        {
                            'audio_filepath': audio_path,
                            'duration': duration,
                            'text': text
                        },
                        ensure_ascii=False))
        manifest_path = manifest_path_prefix + '.' + data_type
        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
            for line in json_lines:
                fout.write(line + '\n')
|
||||
|
||||
|
||||
def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download and unpack the dataset unless ``<target_dir>/data_aishell``
    already exists, then create the manifest files."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if os.path.exists(data_dir):
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    else:
        archive_path = download(url, md5sum, target_dir)
        unpack(archive_path, target_dir)
        # unpack all audio tar files (each one is removed after extraction)
        wav_root = os.path.join(data_dir, 'wav')
        for folder, _, names in sorted(os.walk(wav_root)):
            for name in names:
                unpack(os.path.join(folder, name), folder, True)
    create_manifest(data_dir, manifest_path)
|
||||
|
||||
|
||||
def main():
    """Expand a leading '~' in ``args.target_dir`` and prepare the dataset."""
    target = args.target_dir
    if target.startswith('~'):
        target = os.path.expanduser(target)
        args.target_dir = target

    prepare_dataset(
        url=DATA_URL,
        md5sum=MD5_DATA,
        target_dir=args.target_dir,
        manifest_path=args.manifest_prefix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,63 @@
|
||||
"""Contains data helper functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import codecs
|
||||
import os
|
||||
import tarfile
|
||||
from paddle.v2.dataset.common import md5file
|
||||
|
||||
|
||||
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Every line of the manifest is parsed as json. Instances with durations
    outside [min_duration, max_duration] will be filtered out.

    :param manifest_path: Manifest file to load and parse (read as utf-8).
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    # The with-block closes the file on every exit path, including the
    # IOError raised for a malformed line.
    with codecs.open(manifest_path, 'r', 'utf-8') as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest
|
||||
|
||||
|
||||
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    The file name is the last '/'-separated component of ``url``. The
    download is skipped when that file already exists in ``target_dir``
    with a matching md5 sum.

    :param url: URL of the file to fetch.
    :type url: basestring
    :param md5sum: Expected md5 sum of the file.
    :type md5sum: basestring
    :param target_dir: Directory to store the file in (created if missing).
    :type target_dir: basestring
    :return: Local path of the file.
    :rtype: basestring
    :raises RuntimeError: If the md5 sum of the downloaded file mismatches.
    """
    # Local import keeps this function self-contained.
    import subprocess

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        # Argument list instead of a shell string: spaces or shell
        # metacharacters in url / target_dir are passed through literally.
        subprocess.call(["wget", "-c", url, "-P", target_dir])
        print("\nMD5 Chesksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath
|
||||
|
||||
|
||||
def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir.

    :param filepath: Path of the tar archive to extract.
    :type filepath: basestring
    :param target_dir: Directory to extract into.
    :type target_dir: basestring
    :param rm_tar: Whether to delete the archive after extraction.
    :type rm_tar: bool
    """
    print("Unpacking %s ..." % filepath)
    # The with-block closes the archive even when extraction raises.
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive; only use it on archives from a trusted source.
    with tarfile.open(filepath) as tar:
        tar.extractall(target_dir)
    if rm_tar:
        os.remove(filepath)
|
@ -1,34 +0,0 @@
|
||||
"""Contains data helper functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Every line of the manifest is parsed as json. Instances with durations
    outside [min_duration, max_duration] will be filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    # The with-block closes the file on every exit path, including the
    # IOError raised for a malformed line.
    with open(manifest_path) as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest
|
@ -1,13 +0,0 @@
|
||||
# Prepare LibriSpeech and assemble the train/dev/test manifests.

# Stop if the directory is missing instead of running the steps below from
# the wrong place.
cd librispeech || exit 1
python librispeech.py
if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi
cd -

# All train manifests are concatenated and shuffled into a single one.
cat librispeech/manifest.train* | shuf > manifest.train
cat librispeech/manifest.dev-clean > manifest.dev
cat librispeech/manifest.test-clean > manifest.test

echo "All done."
|
@ -1,10 +0,0 @@
|
||||
# Prepare the CHiME3 background noise and merge its manifests.

# Stop if the directory is missing instead of running the steps below from
# the wrong place.
cd noise || exit 1
python chime3_background.py
if [ $? -ne 0 ]; then
    echo "Prepare CHiME3 background noise failed. Terminated."
    exit 1
fi
cd -

cat noise/manifest.* > manifest.noise
echo "All done."
|
@ -1,28 +0,0 @@
|
||||
'
|
||||
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
@ -0,0 +1,19 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Prepend ``path`` to ``sys.path`` unless it is already listed there."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
||||
|
||||
|
||||
# Directory that contains this module.
this_dir = os.path.dirname(__file__)

# Add project path (the parent directory) to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
|
@ -0,0 +1,204 @@
|
||||
#include "ctc_beam_search_decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "ThreadPool.h"
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
// dimension check
|
||||
size_t num_time_steps = probs_seq.size();
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||
vocabulary.size() + 1,
|
||||
"The shape of probs_seq does not match with "
|
||||
"the shape of the vocabulary");
|
||||
}
|
||||
|
||||
// assign blank id
|
||||
size_t blank_id = vocabulary.size();
|
||||
|
||||
// assign space id
|
||||
auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
|
||||
int space_id = it - vocabulary.begin();
|
||||
// if no space in vocabulary
|
||||
if ((size_t)space_id >= vocabulary.size()) {
|
||||
space_id = -2;
|
||||
}
|
||||
|
||||
// init prefixes' root
|
||||
PathTrie root;
|
||||
root.score = root.log_prob_b_prev = 0.0;
|
||||
std::vector<PathTrie *> prefixes;
|
||||
prefixes.push_back(&root);
|
||||
|
||||
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
|
||||
auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
|
||||
fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
|
||||
root.set_dictionary(dict_ptr);
|
||||
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
|
||||
root.set_matcher(matcher);
|
||||
}
|
||||
|
||||
// prefix search over time
|
||||
for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
|
||||
auto &prob = probs_seq[time_step];
|
||||
|
||||
float min_cutoff = -NUM_FLT_INF;
|
||||
bool full_beam = false;
|
||||
if (ext_scorer != nullptr) {
|
||||
size_t num_prefixes = std::min(prefixes.size(), beam_size);
|
||||
std::sort(
|
||||
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
|
||||
min_cutoff = prefixes[num_prefixes - 1]->score +
|
||||
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
|
||||
full_beam = (num_prefixes == beam_size);
|
||||
}
|
||||
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx =
|
||||
get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
|
||||
// loop over chars
|
||||
for (size_t index = 0; index < log_prob_idx.size(); index++) {
|
||||
auto c = log_prob_idx[index].first;
|
||||
auto log_prob_c = log_prob_idx[index].second;
|
||||
|
||||
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
|
||||
auto prefix = prefixes[i];
|
||||
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
|
||||
break;
|
||||
}
|
||||
// blank
|
||||
if (c == blank_id) {
|
||||
prefix->log_prob_b_cur =
|
||||
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
|
||||
continue;
|
||||
}
|
||||
// repeated character
|
||||
if (c == prefix->character) {
|
||||
prefix->log_prob_nb_cur = log_sum_exp(
|
||||
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
|
||||
}
|
||||
// get new prefix
|
||||
auto prefix_new = prefix->get_path_trie(c);
|
||||
|
||||
if (prefix_new != nullptr) {
|
||||
float log_p = -NUM_FLT_INF;
|
||||
|
||||
if (c == prefix->character &&
|
||||
prefix->log_prob_b_prev > -NUM_FLT_INF) {
|
||||
log_p = log_prob_c + prefix->log_prob_b_prev;
|
||||
} else if (c != prefix->character) {
|
||||
log_p = log_prob_c + prefix->score;
|
||||
}
|
||||
|
||||
// language model scoring
|
||||
if (ext_scorer != nullptr &&
|
||||
(c == space_id || ext_scorer->is_character_based())) {
|
||||
PathTrie *prefix_toscore = nullptr;
|
||||
// skip scoring the space
|
||||
if (ext_scorer->is_character_based()) {
|
||||
prefix_toscore = prefix_new;
|
||||
} else {
|
||||
prefix_toscore = prefix;
|
||||
}
|
||||
|
||||
double score = 0.0;
|
||||
std::vector<std::string> ngram;
|
||||
ngram = ext_scorer->make_ngram(prefix_toscore);
|
||||
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
|
||||
log_p += score;
|
||||
log_p += ext_scorer->beta;
|
||||
}
|
||||
prefix_new->log_prob_nb_cur =
|
||||
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
|
||||
}
|
||||
} // end of loop over prefix
|
||||
} // end of loop over vocabulary
|
||||
|
||||
prefixes.clear();
|
||||
// update log probs
|
||||
root.iterate_to_vec(prefixes);
|
||||
|
||||
// only preserve top beam_size prefixes
|
||||
if (prefixes.size() >= beam_size) {
|
||||
std::nth_element(prefixes.begin(),
|
||||
prefixes.begin() + beam_size,
|
||||
prefixes.end(),
|
||||
prefix_compare);
|
||||
for (size_t i = beam_size; i < prefixes.size(); ++i) {
|
||||
prefixes[i]->remove();
|
||||
}
|
||||
}
|
||||
} // end of loop over time
|
||||
|
||||
// compute aproximate ctc score as the return score, without affecting the
|
||||
// return order of decoding result. To delete when decoder gets stable.
|
||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||
double approx_ctc = prefixes[i]->score;
|
||||
if (ext_scorer != nullptr) {
|
||||
std::vector<int> output;
|
||||
prefixes[i]->get_path_vec(output);
|
||||
auto prefix_length = output.size();
|
||||
auto words = ext_scorer->split_labels(output);
|
||||
// remove word insert
|
||||
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
|
||||
// remove language model weight:
|
||||
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
|
||||
}
|
||||
prefixes[i]->approx_ctc = approx_ctc;
|
||||
}
|
||||
|
||||
return get_beam_search_result(prefixes, vocabulary, beam_size);
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
ctc_beam_search_decoder_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
ThreadPool pool(num_processes);
|
||||
// number of samples
|
||||
size_t batch_size = probs_split.size();
|
||||
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
||||
probs_split[i],
|
||||
vocabulary,
|
||||
beam_size,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
ext_scorer));
|
||||
}
|
||||
|
||||
// get decoding results
|
||||
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
batch_results.emplace_back(res[i].get());
|
||||
}
|
||||
return batch_results;
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
#ifndef CTC_BEAM_SEARCH_DECODER_H_
|
||||
#define CTC_BEAM_SEARCH_DECODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "scorer.h"
|
||||
|
||||
/* CTC Beam Search Decoder
|
||||
|
||||
* Parameters:
|
||||
* probs_seq: 2-D vector that each element is a vector of probabilities
|
||||
* over vocabulary of one time step.
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* beam_size: The width of beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A vector that each element is a pair of score and decoding result,
|
||||
* in descending order.
|
||||
*/
|
||||
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr);
|
||||
|
||||
/* CTC Beam Search Decoder for batch data
|
||||
|
||||
* Parameters:
|
||||
* probs_split: 3-D vector that each element is a 2-D vector that can be used
|
||||
* by ctc_beam_search_decoder().
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* beam_size: The width of beam search.
|
||||
* num_processes: Number of threads for beam search.
|
||||
* cutoff_prob: Cutoff probability for pruning.
|
||||
* cutoff_top_n: Cutoff number for pruning.
|
||||
* ext_scorer: External scorer to evaluate a prefix, which consists of
|
||||
* n-gram language model scoring and word insertion term.
|
||||
* Default null, decoding the input sample without scorer.
|
||||
* Return:
|
||||
* A 2-D vector that each element is a vector of beam search decoding
|
||||
* result for one audio sample.
|
||||
*/
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
ctc_beam_search_decoder_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob = 1.0,
|
||||
size_t cutoff_top_n = 40,
|
||||
Scorer *ext_scorer = nullptr);
|
||||
|
||||
#endif // CTC_BEAM_SEARCH_DECODER_H_
|
@ -0,0 +1,45 @@
|
||||
#include "ctc_greedy_decoder.h"
|
||||
#include "decoder_utils.h"
|
||||
|
||||
std::string ctc_greedy_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary) {
|
||||
// dimension check
|
||||
size_t num_time_steps = probs_seq.size();
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||
vocabulary.size() + 1,
|
||||
"The shape of probs_seq does not match with "
|
||||
"the shape of the vocabulary");
|
||||
}
|
||||
|
||||
size_t blank_id = vocabulary.size();
|
||||
|
||||
std::vector<size_t> max_idx_vec(num_time_steps, 0);
|
||||
std::vector<size_t> idx_vec;
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
double max_prob = 0.0;
|
||||
size_t max_idx = 0;
|
||||
const std::vector<double> &probs_step = probs_seq[i];
|
||||
for (size_t j = 0; j < probs_step.size(); ++j) {
|
||||
if (max_prob < probs_step[j]) {
|
||||
max_idx = j;
|
||||
max_prob = probs_step[j];
|
||||
}
|
||||
}
|
||||
// id with maximum probability in current time step
|
||||
max_idx_vec[i] = max_idx;
|
||||
// deduplicate
|
||||
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
|
||||
idx_vec.push_back(max_idx_vec[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_path_result;
|
||||
for (size_t i = 0; i < idx_vec.size(); ++i) {
|
||||
if (idx_vec[i] != blank_id) {
|
||||
best_path_result += vocabulary[idx_vec[i]];
|
||||
}
|
||||
}
|
||||
return best_path_result;
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
#ifndef CTC_GREEDY_DECODER_H
|
||||
#define CTC_GREEDY_DECODER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
/* CTC Greedy (Best Path) Decoder
|
||||
*
|
||||
* Parameters:
|
||||
* probs_seq: 2-D vector that each element is a vector of probabilities
|
||||
* over vocabulary of one time step.
|
||||
* vocabulary: A vector of vocabulary.
|
||||
* Return:
|
||||
* The decoding result in string
|
||||
*/
|
||||
std::string ctc_greedy_decoder(
|
||||
const std::vector<std::vector<double>>& probs_seq,
|
||||
const std::vector<std::string>& vocabulary);
|
||||
|
||||
#endif // CTC_GREEDY_DECODER_H
|
@ -0,0 +1,176 @@
|
||||
#include "decoder_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n) {
|
||||
std::vector<std::pair<int, double>> prob_idx;
|
||||
for (size_t i = 0; i < prob_step.size(); ++i) {
|
||||
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
|
||||
}
|
||||
// pruning of vacobulary
|
||||
size_t cutoff_len = prob_step.size();
|
||||
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
|
||||
std::sort(
|
||||
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
||||
if (cutoff_prob < 1.0) {
|
||||
double cum_prob = 0.0;
|
||||
cutoff_len = 0;
|
||||
for (size_t i = 0; i < prob_idx.size(); ++i) {
|
||||
cum_prob += prob_idx[i].second;
|
||||
cutoff_len += 1;
|
||||
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
|
||||
}
|
||||
}
|
||||
prob_idx = std::vector<std::pair<int, double>>(
|
||||
prob_idx.begin(), prob_idx.begin() + cutoff_len);
|
||||
}
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx;
|
||||
for (size_t i = 0; i < cutoff_len; ++i) {
|
||||
log_prob_idx.push_back(std::pair<int, float>(
|
||||
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
||||
}
|
||||
return log_prob_idx;
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::pair<double, std::string>> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size) {
|
||||
// allow for the post processing
|
||||
std::vector<PathTrie *> space_prefixes;
|
||||
if (space_prefixes.empty()) {
|
||||
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
|
||||
space_prefixes.push_back(prefixes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
|
||||
std::vector<std::pair<double, std::string>> output_vecs;
|
||||
for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
|
||||
std::vector<int> output;
|
||||
space_prefixes[i]->get_path_vec(output);
|
||||
// convert index to string
|
||||
std::string output_str;
|
||||
for (size_t j = 0; j < output.size(); j++) {
|
||||
output_str += vocabulary[output[j]];
|
||||
}
|
||||
std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
|
||||
output_str);
|
||||
output_vecs.emplace_back(output_pair);
|
||||
}
|
||||
|
||||
return output_vecs;
|
||||
}
|
||||
|
||||
// Number of UTF-8 characters in `str`: every byte that is not a continuation
// byte (bit pattern 10xxxxxx) starts a new character.
size_t get_utf8_str_len(const std::string &str) {
  return static_cast<size_t>(
      std::count_if(str.begin(), str.end(), [](char byte) {
        return (byte & 0xc0) != 0x80;
      }));
}
|
||||
|
||||
// Split `str` into one string per UTF-8 character. A byte that is not a
// continuation byte (10xxxxxx) closes the character collected so far.
std::vector<std::string> split_utf8_str(const std::string &str) {
  std::vector<std::string> result;
  std::string current;

  for (const char byte : str) {
    const bool starts_new_char = (byte & 0xc0) != 0x80;
    if (starts_new_char && !current.empty()) {
      result.push_back(current);
      current.clear();
    }
    current.push_back(byte);
  }
  // the pending buffer is always emitted, so an empty input yields a single
  // empty string
  result.push_back(current);
  return result;
}
|
||||
|
||||
// Split `s` on every occurrence of `delim`, dropping empty pieces (so
// leading, trailing and repeated delimiters are trimmed).
//
// An empty delimiter cannot split anything: the whole string is returned as
// the single piece (or nothing when `s` is empty). Without this guard
// s.find("", start) keeps returning `start` and the loop never advances.
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  if (delim.empty()) {
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }

  const std::size_t delim_len = delim.size();
  std::size_t start = 0;
  while (start < s.size()) {
    const std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      // no further delimiter: the remainder is the last piece
      result.push_back(s.substr(start));
      break;
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}
|
||||
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
|
||||
if (x->score == y->score) {
|
||||
if (x->character == y->character) {
|
||||
return false;
|
||||
} else {
|
||||
return (x->character < y->character);
|
||||
}
|
||||
} else {
|
||||
return x->score > y->score;
|
||||
}
|
||||
}
|
||||
|
||||
void add_word_to_fst(const std::vector<int> &word,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
if (dictionary->NumStates() == 0) {
|
||||
fst::StdVectorFst::StateId start = dictionary->AddState();
|
||||
assert(start == 0);
|
||||
dictionary->SetStart(start);
|
||||
}
|
||||
fst::StdVectorFst::StateId src = dictionary->Start();
|
||||
fst::StdVectorFst::StateId dst;
|
||||
for (auto c : word) {
|
||||
dst = dictionary->AddState();
|
||||
dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
|
||||
src = dst;
|
||||
}
|
||||
dictionary->SetFinal(dst, fst::StdArc::Weight::One());
|
||||
}
|
||||
|
||||
bool add_word_to_dictionary(
|
||||
const std::string &word,
|
||||
const std::unordered_map<std::string, int> &char_map,
|
||||
bool add_space,
|
||||
int SPACE_ID,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
auto characters = split_utf8_str(word);
|
||||
|
||||
std::vector<int> int_word;
|
||||
|
||||
for (auto &c : characters) {
|
||||
if (c == " ") {
|
||||
int_word.push_back(SPACE_ID);
|
||||
} else {
|
||||
auto int_c = char_map.find(c);
|
||||
if (int_c != char_map.end()) {
|
||||
int_word.push_back(int_c->second);
|
||||
} else {
|
||||
return false; // return without adding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (add_space) {
|
||||
int_word.push_back(SPACE_ID);
|
||||
}
|
||||
|
||||
add_word_to_fst(int_word, dictionary);
|
||||
return true; // return with successful adding
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
#ifndef DECODER_UTILS_H_
|
||||
#define DECODER_UTILS_H_
|
||||
|
||||
#include <utility>
|
||||
#include "fst/log.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
const float NUM_FLT_INF = std::numeric_limits<float>::max();
|
||||
const float NUM_FLT_MIN = std::numeric_limits<float>::min();
|
||||
|
||||
// inline function for validation check
|
||||
inline void check(
|
||||
bool x, const char *expr, const char *file, int line, const char *err) {
|
||||
if (!x) {
|
||||
std::cout << "[" << file << ":" << line << "] ";
|
||||
LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
|
||||
}
|
||||
}
|
||||
|
||||
#define VALID_CHECK(x, info) \
|
||||
check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
|
||||
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
|
||||
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
|
||||
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)
|
||||
|
||||
|
||||
// Comparator for descending order on the first member of a pair:
// true when a.first is greater than b.first.
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
  const T1 &lhs = a.first;
  const T1 &rhs = b.first;
  return lhs > rhs;
}
|
||||
|
||||
// Comparator for descending order on the second member of a pair:
// true when a.second is greater than b.second.
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
  const T2 &lhs = a.second;
  const T2 &rhs = b.second;
  return lhs > rhs;
}
|
||||
|
||||
// Return log(exp(x) + exp(y)), i.e. the sum of two probabilities given in
// log scale. A value at or below -numeric_limits<T>::max() stands for
// "log of zero", in which case the other operand is returned unchanged.
template <typename T>
T log_sum_exp(const T &x, const T &y) {
  static const T log_zero = -std::numeric_limits<T>::max();
  if (x <= log_zero) {
    return y;
  }
  if (y <= log_zero) {
    return x;
  }
  // factor out the larger operand so the exponentials cannot overflow
  const T xmax = (x < y) ? y : x;
  return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}
|
||||
|
||||
// Get pruned probability vector for each time step's beam search
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n);
|
||||
|
||||
// Get beam search result from prefixes in trie tree
|
||||
std::vector<std::pair<double, std::string>> get_beam_search_result(
|
||||
const std::vector<PathTrie *> &prefixes,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size);
|
||||
|
||||
// Comparison function for prefixes
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y);
|
||||
|
||||
/* Get length of utf8 encoding string
|
||||
* See: http://stackoverflow.com/a/4063229
|
||||
*/
|
||||
size_t get_utf8_str_len(const std::string &str);
|
||||
|
||||
/* Split a string into a list of strings on a given string
|
||||
* delimiter. NB: delimiters on beginning / end of string are
|
||||
* trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
|
||||
*/
|
||||
std::vector<std::string> split_str(const std::string &s,
|
||||
const std::string &delim);
|
||||
|
||||
/* Splits string into vector of strings representing
|
||||
* UTF-8 characters (not same as chars)
|
||||
*/
|
||||
std::vector<std::string> split_utf8_str(const std::string &str);
|
||||
|
||||
// Add a word, given as a sequence of indices, to the FST dictionary
|
||||
void add_word_to_fst(const std::vector<int> &word,
|
||||
fst::StdVectorFst *dictionary);
|
||||
|
||||
// Add a word in string to dictionary
|
||||
bool add_word_to_dictionary(
|
||||
const std::string &word,
|
||||
const std::unordered_map<std::string, int> &char_map,
|
||||
bool add_space,
|
||||
int SPACE_ID,
|
||||
fst::StdVectorFst *dictionary);
|
||||
#endif  // DECODER_UTILS_H_
|
@ -0,0 +1,33 @@
|
||||
%module swig_decoders
|
||||
%{
|
||||
#include "scorer.h"
|
||||
#include "ctc_greedy_decoder.h"
|
||||
#include "ctc_beam_search_decoder.h"
|
||||
#include "decoder_utils.h"
|
||||
%}
|
||||
|
||||
%include "std_vector.i"
|
||||
%include "std_pair.i"
|
||||
%include "std_string.i"
|
||||
%import "decoder_utils.h"
|
||||
|
||||
namespace std {
|
||||
%template(DoubleVector) std::vector<double>;
|
||||
%template(IntVector) std::vector<int>;
|
||||
%template(StringVector) std::vector<std::string>;
|
||||
%template(VectorOfStructVector) std::vector<std::vector<double> >;
|
||||
%template(FloatVector) std::vector<float>;
|
||||
%template(Pair) std::pair<float, std::string>;
|
||||
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
|
||||
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
|
||||
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
|
||||
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
|
||||
}
|
||||
|
||||
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
|
||||
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
|
||||
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;
|
||||
|
||||
%include "scorer.h"
|
||||
%include "ctc_greedy_decoder.h"
|
||||
%include "ctc_beam_search_decoder.h"
|
@ -0,0 +1,148 @@
|
||||
#include "path_trie.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
// Build a root node: all log probabilities and scores start at -NUM_FLT_INF
// (standing for log of zero), the character is the ROOT_ marker (-1), and no
// dictionary or matcher is attached.
PathTrie::PathTrie() {
  log_prob_b_prev = -NUM_FLT_INF;
  log_prob_nb_prev = -NUM_FLT_INF;
  log_prob_b_cur = -NUM_FLT_INF;
  log_prob_nb_cur = -NUM_FLT_INF;
  score = -NUM_FLT_INF;
  // approx_ctc used to be left uninitialized; give it the same sentinel as
  // the other scores so that reading it before it is assigned is defined
  approx_ctc = -NUM_FLT_INF;

  ROOT_ = -1;
  character = ROOT_;
  exists_ = true;
  parent = nullptr;

  dictionary_ = nullptr;
  dictionary_state_ = 0;
  has_dictionary_ = false;

  matcher_ = nullptr;
}
|
||||
|
||||
// A node owns its children: destroying it destroys the whole subtree.
PathTrie::~PathTrie() {
  for (auto& entry : children_) {
    delete entry.second;
  }
}
|
||||
|
||||
// Return the child reached by appending `new_char` to this prefix, creating
// it when needed.
//
// - An existing child is reused; if it had been removed (exists_ == false)
//   it is revived with all four log probabilities reset to -NUM_FLT_INF.
// - Without a dictionary a new child is always created.
// - With a dictionary a new child is created only when the FST has an arc
//   for `new_char` leaving the current dictionary state. Otherwise nullptr
//   is returned, and if the current state is final and `reset` is true the
//   dictionary state is moved back to the FST start state.
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
  const auto child = std::find_if(
      children_.begin(),
      children_.end(),
      [new_char](const std::pair<int, PathTrie*>& entry) {
        return entry.first == new_char;
      });

  if (child != children_.end()) {
    PathTrie* node = child->second;
    if (!node->exists_) {
      node->exists_ = true;
      node->log_prob_b_prev = -NUM_FLT_INF;
      node->log_prob_nb_prev = -NUM_FLT_INF;
      node->log_prob_b_cur = -NUM_FLT_INF;
      node->log_prob_nb_cur = -NUM_FLT_INF;
    }
    return node;
  }

  if (!has_dictionary_) {
    PathTrie* new_path = new PathTrie;
    new_path->character = new_char;
    new_path->parent = this;
    children_.push_back(std::make_pair(new_char, new_path));
    return new_path;
  }

  matcher_->SetState(dictionary_state_);
  if (matcher_->Find(new_char)) {
    // the child continues from the state the matched arc leads to
    PathTrie* new_path = new PathTrie;
    new_path->character = new_char;
    new_path->parent = this;
    new_path->dictionary_ = dictionary_;
    new_path->dictionary_state_ = matcher_->Value().nextstate;
    new_path->has_dictionary_ = true;
    new_path->matcher_ = matcher_;
    children_.push_back(std::make_pair(new_char, new_path));
    return new_path;
  }

  // Adding this character causes word outside dictionary
  const auto final_weight = dictionary_->Final(dictionary_state_);
  const bool is_final = (final_weight != fst::TropicalWeight::Zero());
  if (is_final && reset) {
    dictionary_state_ = dictionary_->Start();
  }
  return nullptr;
}
|
||||
|
||||
// Collect the characters from the root down to this node into `output`.
// Passing ROOT_ as the stop character makes the walk end only at the root.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
  const int stop_at_root = ROOT_;
  return get_path_vec(output, stop_at_root);
}
|
||||
|
||||
// Walk from this node towards the root, appending each visited character to
// `output`, until a node holding `stop` or the root marker is reached, or
// `output` holds `max_steps` entries. `output` is then reversed as a whole,
// and the node at which the walk stopped is returned (that node's character
// is not appended).
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
  PathTrie* node = this;
  while (node->character != stop && node->character != node->ROOT_ &&
         output.size() != max_steps) {
    output.push_back(node->character);
    node = node->parent;
  }
  std::reverse(output.begin(), output.end());
  return node;
}
|
||||
|
||||
// Advance the subtree by one time step and list its live prefixes.
// For every existing node the "cur" log probabilities become the "prev"
// ones, "cur" is reset to -NUM_FLT_INF, the score is recomputed as the
// log-sum of the two "prev" values, and the node is appended to `output`.
// Children are visited whether or not the node itself exists.
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
  if (exists_) {
    log_prob_b_prev = log_prob_b_cur;
    log_prob_b_cur = -NUM_FLT_INF;

    log_prob_nb_prev = log_prob_nb_cur;
    log_prob_nb_cur = -NUM_FLT_INF;

    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
    output.push_back(this);
  }
  for (size_t i = 0; i < children_.size(); ++i) {
    children_[i].second->iterate_to_vec(output);
  }
}
|
||||
|
||||
void PathTrie::remove() {
|
||||
exists_ = false;
|
||||
|
||||
if (children_.size() == 0) {
|
||||
auto child = parent->children_.begin();
|
||||
for (child = parent->children_.begin(); child != parent->children_.end();
|
||||
++child) {
|
||||
if (child->first == character) {
|
||||
parent->children_.erase(child);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (parent->children_.size() == 0 && !parent->exists_) {
|
||||
parent->remove();
|
||||
}
|
||||
|
||||
delete this;
|
||||
}
|
||||
}
|
||||
|
||||
// Attach an FST dictionary to this node and position it at the FST's start
// state. `dictionary` is dereferenced, so it must not be null.
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
  has_dictionary_ = true;
  dictionary_state_ = dictionary->Start();
  dictionary_ = dictionary;
}
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
|
||||
matcher_ = matcher;
|
||||
}
|
@ -0,0 +1,67 @@
|
||||
#ifndef PATH_TRIE_H
|
||||
#define PATH_TRIE_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
/* Trie tree for prefix storing and manipulating, with a dictionary in
|
||||
* finite-state transducer for spelling correction.
|
||||
*/
|
||||
class PathTrie {
|
||||
public:
|
||||
PathTrie();
|
||||
~PathTrie();
|
||||
|
||||
// get new prefix after appending new char
|
||||
PathTrie* get_path_trie(int new_char, bool reset = true);
|
||||
|
||||
// get the prefix in index from root to current node
|
||||
PathTrie* get_path_vec(std::vector<int>& output);
|
||||
|
||||
// get the prefix in index from some stop node to current nodel
|
||||
PathTrie* get_path_vec(std::vector<int>& output,
|
||||
int stop,
|
||||
size_t max_steps = std::numeric_limits<size_t>::max());
|
||||
|
||||
// update log probs
|
||||
void iterate_to_vec(std::vector<PathTrie*>& output);
|
||||
|
||||
// set dictionary for FST
|
||||
void set_dictionary(fst::StdVectorFst* dictionary);
|
||||
|
||||
void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);
|
||||
|
||||
bool is_empty() { return ROOT_ == character; }
|
||||
|
||||
// remove current path from root
|
||||
void remove();
|
||||
|
||||
float log_prob_b_prev;
|
||||
float log_prob_nb_prev;
|
||||
float log_prob_b_cur;
|
||||
float log_prob_nb_cur;
|
||||
float score;
|
||||
float approx_ctc;
|
||||
int character;
|
||||
PathTrie* parent;
|
||||
|
||||
private:
|
||||
int ROOT_;
|
||||
bool exists_;
|
||||
bool has_dictionary_;
|
||||
|
||||
std::vector<std::pair<int, PathTrie*>> children_;
|
||||
|
||||
// pointer to dictionary of FST
|
||||
fst::StdVectorFst* dictionary_;
|
||||
fst::StdVectorFst::StateId dictionary_state_;
|
||||
// true if finding ars in FST
|
||||
std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
|
||||
};
|
||||
|
||||
#endif // PATH_TRIE_H
|
@ -0,0 +1,234 @@
|
||||
#include "scorer.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "lm/config.hh"
|
||||
#include "lm/model.hh"
|
||||
#include "lm/state.hh"
|
||||
#include "util/string_piece.hh"
|
||||
#include "util/tokenize_piece.hh"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
// Create a scorer with language model weight `alpha` and word insertion
// weight `beta`, then load the language model at `lm_path` and index the
// characters of `vocab_list` (see setup()).
Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::vector<std::string>& vocab_list)
    : alpha(alpha),
      beta(beta),
      dictionary(nullptr),
      language_model_(nullptr),
      is_character_based_(true),
      max_order_(0),
      dict_size_(0),
      SPACE_ID_(-1) {
  setup(lm_path, vocab_list);
}
|
||||
|
||||
// Release the language model and the FST dictionary. Both are stored as
// void*, so each is cast back to its real type before deletion; deleting a
// null pointer is a no-op, so no explicit null checks are needed.
Scorer::~Scorer() {
  delete static_cast<lm::base::Model*>(language_model_);
  delete static_cast<fst::StdVectorFst*>(dictionary);
}
|
||||
|
||||
void Scorer::setup(const std::string& lm_path,
|
||||
const std::vector<std::string>& vocab_list) {
|
||||
// load language model
|
||||
load_lm(lm_path);
|
||||
// set char map for scorer
|
||||
set_char_map(vocab_list);
|
||||
// fill the dictionary for FST
|
||||
if (!is_character_based()) {
|
||||
fill_dictionary(true);
|
||||
}
|
||||
}
|
||||
|
||||
void Scorer::load_lm(const std::string& lm_path) {
|
||||
const char* filename = lm_path.c_str();
|
||||
VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
|
||||
|
||||
RetriveStrEnumerateVocab enumerate;
|
||||
lm::ngram::Config config;
|
||||
config.enumerate_vocab = &enumerate;
|
||||
language_model_ = lm::ngram::LoadVirtual(filename, config);
|
||||
max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
|
||||
vocabulary_ = enumerate.vocabulary;
|
||||
for (size_t i = 0; i < vocabulary_.size(); ++i) {
|
||||
if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
|
||||
vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
|
||||
get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
|
||||
is_character_based_ = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
|
||||
double cond_prob;
|
||||
lm::ngram::State state, tmp_state, out_state;
|
||||
// avoid to inserting <s> in begin
|
||||
model->NullContextWrite(&state);
|
||||
for (size_t i = 0; i < words.size(); ++i) {
|
||||
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||
// encounter OOV
|
||||
if (word_index == 0) {
|
||||
return OOV_SCORE;
|
||||
}
|
||||
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||
tmp_state = state;
|
||||
state = out_state;
|
||||
out_state = tmp_state;
|
||||
}
|
||||
// return log10 prob
|
||||
return cond_prob;
|
||||
}
|
||||
|
||||
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
|
||||
std::vector<std::string> sentence;
|
||||
if (words.size() == 0) {
|
||||
for (size_t i = 0; i < max_order_; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < max_order_ - 1; ++i) {
|
||||
sentence.push_back(START_TOKEN);
|
||||
}
|
||||
sentence.insert(sentence.end(), words.begin(), words.end());
|
||||
}
|
||||
sentence.push_back(END_TOKEN);
|
||||
return get_log_prob(sentence);
|
||||
}
|
||||
|
||||
double Scorer::get_log_prob(const std::vector<std::string>& words) {
|
||||
assert(words.size() > max_order_);
|
||||
double score = 0.0;
|
||||
for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
|
||||
std::vector<std::string> ngram(words.begin() + i,
|
||||
words.begin() + i + max_order_);
|
||||
score += get_log_cond_prob(ngram);
|
||||
}
|
||||
return score;
|
||||
}
|
||||
|
||||
// Replace the language model weight and the word insertion weight.
// The float arguments are widened to the double members.
void Scorer::reset_params(float alpha, float beta) {
  this->beta = static_cast<double>(beta);
  this->alpha = static_cast<double>(alpha);
}
|
||||
|
||||
std::string Scorer::vec2str(const std::vector<int>& input) {
|
||||
std::string word;
|
||||
for (auto ind : input) {
|
||||
word += char_list_[ind];
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
|
||||
if (labels.empty()) return {};
|
||||
|
||||
std::string s = vec2str(labels);
|
||||
std::vector<std::string> words;
|
||||
if (is_character_based_) {
|
||||
words = split_utf8_str(s);
|
||||
} else {
|
||||
words = split_str(s, " ");
|
||||
}
|
||||
return words;
|
||||
}
|
||||
|
||||
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
|
||||
char_list_ = char_list;
|
||||
char_map_.clear();
|
||||
|
||||
for (size_t i = 0; i < char_list_.size(); i++) {
|
||||
if (char_list_[i] == " ") {
|
||||
SPACE_ID_ = i;
|
||||
char_map_[' '] = i;
|
||||
} else if (char_list_[i].size() == 1) {
|
||||
char_map_[char_list_[i][0]] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the n-gram (max_order_ tokens, oldest first) that ends at `prefix`.
// Tokens are read backwards from the trie: one label at a time for a
// character based model, one space delimited word at a time otherwise.
// When the root is reached early the remaining slots are filled with
// START_TOKEN.
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
  std::vector<std::string> ngram;
  PathTrie* cursor = prefix;

  for (size_t order = 0; order < max_order_; order++) {
    std::vector<int> token_labels;
    PathTrie* stop_node = nullptr;

    if (is_character_based_) {
      stop_node = cursor->get_path_vec(token_labels, SPACE_ID_, 1);
      cursor = stop_node;
    } else {
      stop_node = cursor->get_path_vec(token_labels, SPACE_ID_);
      cursor = stop_node->parent;  // Skipping spaces
    }

    // reconstruct word
    ngram.push_back(vec2str(token_labels));

    if (stop_node->character == -1) {
      // No more spaces, but still need order
      ngram.insert(ngram.end(), max_order_ - order - 1, START_TOKEN);
      break;
    }
  }
  // tokens were collected newest first
  std::reverse(ngram.begin(), ngram.end());
  return ngram;
}
|
||||
|
||||
void Scorer::fill_dictionary(bool add_space) {
|
||||
fst::StdVectorFst dictionary;
|
||||
// First reverse char_list so ints can be accessed by chars
|
||||
std::unordered_map<std::string, int> char_map;
|
||||
for (size_t i = 0; i < char_list_.size(); i++) {
|
||||
char_map[char_list_[i]] = i;
|
||||
}
|
||||
|
||||
// For each unigram convert to ints and put in trie
|
||||
int dict_size = 0;
|
||||
for (const auto& word : vocabulary_) {
|
||||
bool added = add_word_to_dictionary(
|
||||
word, char_map, add_space, SPACE_ID_, &dictionary);
|
||||
dict_size += added ? 1 : 0;
|
||||
}
|
||||
|
||||
dict_size_ = dict_size;
|
||||
|
||||
/* Simplify FST
|
||||
|
||||
* This gets rid of "epsilon" transitions in the FST.
|
||||
* These are transitions that don't require a string input to be taken.
|
||||
* Getting rid of them is necessary to make the FST determinisitc, but
|
||||
* can greatly increase the size of the FST
|
||||
*/
|
||||
fst::RmEpsilon(&dictionary);
|
||||
fst::StdVectorFst* new_dict = new fst::StdVectorFst;
|
||||
|
||||
/* This makes the FST deterministic, meaning for any string input there's
|
||||
* only one possible state the FST could be in. It is assumed our
|
||||
* dictionary is deterministic when using it.
|
||||
* (lest we'd have to check for multiple transitions at each state)
|
||||
*/
|
||||
fst::Determinize(dictionary, new_dict);
|
||||
|
||||
/* Finds the simplest equivalent fst. This is unnecessary but decreases
|
||||
* memory usage of the dictionary
|
||||
*/
|
||||
fst::Minimize(new_dict);
|
||||
this->dictionary = new_dict;
|
||||
}
|
@ -0,0 +1,112 @@
|
||||
#ifndef SCORER_H_
|
||||
#define SCORER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "lm/enumerate_vocab.hh"
|
||||
#include "lm/virtual_interface.hh"
|
||||
#include "lm/word_index.hh"
|
||||
#include "util/string_piece.hh"
|
||||
|
||||
#include "path_trie.h"
|
||||
|
||||
const double OOV_SCORE = -1000.0;
|
||||
const std::string START_TOKEN = "<s>";
|
||||
const std::string UNK_TOKEN = "<unk>";
|
||||
const std::string END_TOKEN = "</s>";
|
||||
|
||||
// Implement a callback to retrive the dictionary of language model.
|
||||
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
|
||||
public:
|
||||
RetriveStrEnumerateVocab() {}
|
||||
|
||||
void Add(lm::WordIndex index, const StringPiece &str) {
|
||||
vocabulary.push_back(std::string(str.data(), str.length()));
|
||||
}
|
||||
|
||||
std::vector<std::string> vocabulary;
|
||||
};
|
||||
|
||||
/* External scorer to query score for n-gram or sentence, including language
|
||||
* model scoring and word insertion.
|
||||
*
|
||||
* Example:
|
||||
* Scorer scorer(alpha, beta, "path_of_language_model");
|
||||
* scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
|
||||
* scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
|
||||
*/
|
||||
class Scorer {
|
||||
public:
|
||||
Scorer(double alpha,
|
||||
double beta,
|
||||
const std::string &lm_path,
|
||||
const std::vector<std::string> &vocabulary);
|
||||
~Scorer();
|
||||
|
||||
double get_log_cond_prob(const std::vector<std::string> &words);
|
||||
|
||||
double get_sent_log_prob(const std::vector<std::string> &words);
|
||||
|
||||
// return the max order
|
||||
size_t get_max_order() const { return max_order_; }
|
||||
|
||||
// return the dictionary size of language model
|
||||
size_t get_dict_size() const { return dict_size_; }
|
||||
|
||||
// retrun true if the language model is character based
|
||||
bool is_character_based() const { return is_character_based_; }
|
||||
|
||||
// reset params alpha & beta
|
||||
void reset_params(float alpha, float beta);
|
||||
|
||||
// make ngram for a given prefix
|
||||
std::vector<std::string> make_ngram(PathTrie *prefix);
|
||||
|
||||
// trransform the labels in index to the vector of words (word based lm) or
|
||||
// the vector of characters (character based lm)
|
||||
std::vector<std::string> split_labels(const std::vector<int> &labels);
|
||||
|
||||
// language model weight
|
||||
double alpha;
|
||||
// word insertion weight
|
||||
double beta;
|
||||
|
||||
// pointer to the dictionary of FST
|
||||
void *dictionary;
|
||||
|
||||
protected:
|
||||
// necessary setup: load language model, set char map, fill FST's dictionary
|
||||
void setup(const std::string &lm_path,
|
||||
const std::vector<std::string> &vocab_list);
|
||||
|
||||
// load language model from given path
|
||||
void load_lm(const std::string &lm_path);
|
||||
|
||||
// fill dictionary for FST
|
||||
void fill_dictionary(bool add_space);
|
||||
|
||||
// set char map
|
||||
void set_char_map(const std::vector<std::string> &char_list);
|
||||
|
||||
double get_log_prob(const std::vector<std::string> &words);
|
||||
|
||||
// translate the vector in index to string
|
||||
std::string vec2str(const std::vector<int> &input);
|
||||
|
||||
private:
|
||||
void *language_model_;
|
||||
bool is_character_based_;
|
||||
size_t max_order_;
|
||||
size_t dict_size_;
|
||||
|
||||
int SPACE_ID_;
|
||||
std::vector<std::string> char_list_;
|
||||
std::unordered_map<char, int> char_map_;
|
||||
|
||||
std::vector<std::string> vocabulary_;
|
||||
};
|
||||
|
||||
#endif // SCORER_H_
|
@ -0,0 +1,121 @@
|
||||
"""Script to build and install decoder package."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from setuptools import setup, Extension, distutils
|
||||
import glob
|
||||
import platform
|
||||
import os, sys
|
||||
import multiprocessing.pool
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--num_processes",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of cpu processes to build package. (default: %(default)d)")
|
||||
args = parser.parse_known_args()
|
||||
|
||||
# reconstruct sys.argv to pass to setup below
|
||||
sys.argv = [sys.argv[0]] + args[1]
|
||||
|
||||
|
||||
# monkey-patch for parallel compilation
|
||||
# See: https://stackoverflow.com/a/13176803
|
||||
def parallelCCompile(self,
                     sources,
                     output_dir=None,
                     macros=None,
                     include_dirs=None,
                     debug=0,
                     extra_preargs=None,
                     extra_postargs=None,
                     depends=None):
    """Compile ``sources`` in parallel; drop-in for ``CCompiler.compile``.

    The number of worker threads comes from the ``--num_processes`` command
    line option parsed at module level. Returns the list of object files, as
    the original method does.
    """
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    # parallel code
    def _single_compile(obj):
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
    try:
        # convert to list, imap is evaluated on-demand
        list(thread_pool.imap(_single_compile, objects))
    finally:
        # the pool used to be left running; shut its workers down whether
        # the compilation succeeded or raised
        thread_pool.terminate()
        thread_pool.join()
    return objects
|
||||
|
||||
|
||||
def compile_test(header, library):
    """Return True if g++ can build an empty program with the header and library.

    The probe runs through bash, includes ``header`` via ``-include``, links
    ``-l<library>``, writes a throw-away ``dummy`` binary next to this script
    and removes it again. All compiler output is discarded.
    """
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = (
        "bash -c \"g++ -include {header} -l{library} -x c++ - "
        "<<<'int main() {{}}' -o {dummy} >/dev/null 2>/dev/null "
        "&& rm {dummy} 2>/dev/null\""
    ).format(header=header, library=library, dummy=dummy_path)
    exit_status = os.system(command)
    return exit_status == 0
|
||||
|
||||
|
||||
# hack compile to support parallel compiling
|
||||
distutils.ccompiler.CCompiler.compile = parallelCCompile
|
||||
|
||||
# Third-party C++ sources compiled into the extension (KenLM and OpenFst),
# without their stand-alone programs and tests.
FILES = []
for pattern in ('kenlm/util/*.cc',
                'kenlm/lm/*.cc',
                'kenlm/util/double-conversion/*.cc',
                'openfst-1.6.3/src/lib/*.cc'):
    FILES.extend(glob.glob(pattern))

# FILES + glob.glob('glog/src/*.cc')
FILES = [
    fn for fn in FILES
    if not fn.endswith(('main.cc', 'test.cc', 'unittest.cc'))
]
|
||||
|
||||
LIBS = ['stdc++']
|
||||
if platform.system() != 'Darwin':
|
||||
LIBS.append('rt')
|
||||
|
||||
ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']
|
||||
|
||||
if compile_test('zlib.h', 'z'):
|
||||
ARGS.append('-DHAVE_ZLIB')
|
||||
LIBS.append('z')
|
||||
|
||||
if compile_test('bzlib.h', 'bz2'):
|
||||
ARGS.append('-DHAVE_BZLIB')
|
||||
LIBS.append('bz2')
|
||||
|
||||
if compile_test('lzma.h', 'lzma'):
|
||||
ARGS.append('-DHAVE_XZLIB')
|
||||
LIBS.append('lzma')
|
||||
|
||||
os.system('swig -python -c++ ./decoders.i')
|
||||
|
||||
decoders_module = [
|
||||
Extension(
|
||||
name='_swig_decoders',
|
||||
sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
|
||||
language='c++',
|
||||
include_dirs=[
|
||||
'.',
|
||||
'kenlm',
|
||||
'openfst-1.6.3/src/include',
|
||||
'ThreadPool',
|
||||
#'glog/src'
|
||||
],
|
||||
libraries=LIBS,
|
||||
extra_compile_args=ARGS)
|
||||
]
|
||||
|
||||
setup(
|
||||
name='swig_decoders',
|
||||
version='0.1',
|
||||
description="""CTC decoders""",
|
||||
ext_modules=decoders_module,
|
||||
py_modules=['swig_decoders'], )
|
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
#
# Fetch the third-party sources used by the decoder build (kenlm,
# openfst-1.6.3, ThreadPool) when their directories are missing, then install
# the decoders with setup.py. A failed fetch now stops the script instead of
# letting the build run against missing sources.

if [ ! -d kenlm ]; then
    if ! git clone https://github.com/luotao1/kenlm.git; then
        echo "Clone kenlm failed. Terminated."
        exit 1
    fi
    echo -e "\n"
fi

if [ ! -d openfst-1.6.3 ]; then
    echo "Download and extract openfst ..."
    if ! wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz; then
        echo "Download openfst failed. Terminated."
        exit 1
    fi
    if ! tar -xzvf openfst-1.6.3.tar.gz; then
        echo "Extract openfst failed. Terminated."
        exit 1
    fi
    echo -e "\n"
fi

if [ ! -d ThreadPool ]; then
    if ! git clone https://github.com/progschj/ThreadPool.git; then
        echo "Clone ThreadPool failed. Terminated."
        exit 1
    fi
    echo -e "\n"
fi

echo "Install decoders ..."
python setup.py install --num_processes 4
|
@ -0,0 +1,116 @@
|
||||
"""Wrapper for various CTC decoders in SWIG."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import swig_decoders
|
||||
|
||||
|
||||
class Scorer(swig_decoders.Scorer):
    """Thin Python subclass of the SWIG-generated ``Scorer``.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :param vocabulary: Forwarded unchanged to the SWIG constructor
                       (presumably the vocabulary list; confirm against
                       the SWIG interface).
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        # All arguments go straight to the SWIG base-class constructor.
        base_init = swig_decoders.Scorer.__init__
        base_init(self, alpha, beta, model_path, vocabulary)
|
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, vocabulary):
    """Wrapper for ctc best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      with a ``tolist()`` method is converted with it first.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # The documented input is a plain list, but `.tolist()` used to be called
    # unconditionally and raised AttributeError on one. Convert only when the
    # object offers `tolist`.
    if hasattr(probs_seq, 'tolist'):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_greedy_decoder(probs_seq, vocabulary)
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
                            vocabulary,
                            beam_size,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank. An object
                      with a ``tolist()`` method is converted with it first.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # The documented input is a plain list, but `.tolist()` used to be called
    # unconditionally and raised AttributeError on one. Convert only when the
    # object offers `tolist`.
    if hasattr(probs_seq, 'tolist'):
        probs_seq = probs_seq.tolist()
    return swig_decoders.ctc_beam_search_decoder(probs_seq, vocabulary,
                                                 beam_size, cutoff_prob,
                                                 cutoff_top_n, ext_scoring_func)
|
||||
|
||||
|
||||
def ctc_beam_search_decoder_batch(probs_split,
                                  vocabulary,
                                  beam_size,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
                        Elements with a ``tolist()`` method are converted
                        with it first.
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability (presumably
             one such list per input sequence; confirm against the SWIG
             interface).
    :rtype: list
    """
    # The documented elements are plain lists, but `.tolist()` used to be
    # called on each one unconditionally. Convert only elements that offer it.
    probs_split = [
        probs_seq.tolist() if hasattr(probs_seq, 'tolist') else probs_seq
        for probs_seq in probs_split
    ]

    return swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func)
|
After Width: | Height: | Size: 153 KiB |
After Width: | Height: | Size: 108 KiB |
@ -1,169 +0,0 @@
|
||||
"""Evaluation for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model import DeepSpeech2Model
|
||||
from error_rate import wer
|
||||
import utils
|
||||
|
||||
# Command-line options for the evaluation script. Every option has a default,
# so the script can run without arguments.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--batch_size",
    default=128,
    type=int,
    help="Minibatch size for evaluation. (default: %(default)s)")
parser.add_argument(
    "--trainer_count",
    default=8,
    type=int,
    help="Trainer number. (default: %(default)s)")
parser.add_argument(
    "--num_conv_layers",
    default=2,
    type=int,
    help="Convolution layer number. (default: %(default)s)")
parser.add_argument(
    "--num_rnn_layers",
    default=3,
    type=int,
    help="RNN layer number. (default: %(default)s)")
parser.add_argument(
    "--rnn_layer_size",
    default=512,
    type=int,
    help="RNN layer cell number. (default: %(default)s)")
parser.add_argument(
    "--use_gpu",
    default=True,
    type=distutils.util.strtobool,
    help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
    "--num_threads_data",
    default=multiprocessing.cpu_count() // 2,
    type=int,
    help="Number of cpu threads for preprocessing data. (default: %(default)s)")
parser.add_argument(
    "--num_processes_beam_search",
    default=multiprocessing.cpu_count() // 2,
    type=int,
    help="Number of cpu processes for beam search. (default: %(default)s)")
parser.add_argument(
    "--mean_std_filepath",
    default='mean_std.npz',
    type=str,
    help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
    "--decode_method",
    default='beam_search',
    type=str,
    help="Method for ctc decoding, best_path or beam_search. "
    "(default: %(default)s)")
parser.add_argument(
    "--language_model_path",
    default="lm/data/common_crawl_00.prune01111.trie.klm",
    type=str,
    help="Path for language model. (default: %(default)s)")
parser.add_argument(
    "--alpha",
    default=0.36,
    type=float,
    help="Parameter associated with language model. (default: %(default)f)")
parser.add_argument(
    "--beta",
    default=0.25,
    type=float,
    help="Parameter associated with word count. (default: %(default)f)")
# The two help fragments used to concatenate to "pruningin beam search";
# the missing space is restored.
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
parser.add_argument(
    "--beam_size",
    default=500,
    type=int,
    help="Width for beam search decoding. (default: %(default)d)")
parser.add_argument(
    "--specgram_type",
    default='linear',
    type=str,
    help="Feature type of audio data: 'linear' (power spectrum)"
    " or 'mfcc'. (default: %(default)s)")
parser.add_argument(
    "--decode_manifest_path",
    default='datasets/manifest.test',
    type=str,
    help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
    "--model_filepath",
    default='checkpoints/params.latest.tar.gz',
    type=str,
    help="Model filepath. (default: %(default)s)")
parser.add_argument(
    "--vocab_filepath",
    default='datasets/vocab/eng_vocab.txt',
    type=str,
    help="Vocabulary filepath. (default: %(default)s)")
|
||||
# Parsed once at import time; evaluate() and main() read this module-level
# namespace directly.
args = parser.parse_args()
|
||||
|
||||
|
||||
def evaluate():
    """Evaluate on whole test data for DeepSpeech2.

    Reads the manifest given by ``args.decode_manifest_path`` in batches,
    decodes every batch with the model loaded from ``args.model_filepath``
    and prints the running and final average WER. All settings come from the
    module-level ``args``.
    """
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.decode_manifest_path,
        batch_size=args.batch_size,
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        pretrained_model_path=args.model_filepath)

    wer_sum, num_ins = 0.0, 0
    for infer_data in batch_reader():
        result_transcripts = ds2_model.infer_batch(
            infer_data=infer_data,
            decode_method=args.decode_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.language_model_path,
            num_processes=args.num_processes_beam_search)
        # Map the token ids of each reference transcript back to text.
        target_transcripts = [
            ''.join([data_generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]
        for target, result in zip(target_transcripts, result_transcripts):
            wer_sum += wer(target, result)
            num_ins += 1
        if num_ins:
            print("WER (%d/?) = %f" % (num_ins, wer_sum / num_ins))
    if not num_ins:
        # An empty manifest used to end in ZeroDivisionError here.
        print("No utterance was evaluated; WER is undefined.")
        return
    print("Final WER (%d/%d) = %f" % (num_ins, num_ins, wer_sum / num_ins))
|
||||
|
||||
|
||||
def main():
    """Print the parsed arguments, initialise paddle, then run evaluate()."""
    utils.print_arguments(args)
    paddle.init(trainer_count=args.trainer_count, use_gpu=args.use_gpu)
    evaluate()


# Run the evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
@ -0,0 +1,42 @@
|
||||
#! /usr/bin/env bash
#
# Aishell data preparation: download and manifests, then vocabulary, then
# mean/stddev for the normalizer. Stops at the first failing step.

pushd ../.. > /dev/null

# download data, generate manifests
if ! PYTHONPATH=.:$PYTHONPATH python data/aishell/aishell.py \
    --manifest_prefix='data/aishell/manifest' \
    --target_dir='~/.cache/paddle/dataset/speech/Aishell'
then
    echo "Prepare Aishell failed. Terminated."
    exit 1
fi

# build vocabulary
if ! python tools/build_vocab.py \
    --count_threshold=0 \
    --vocab_path='data/aishell/vocab.txt' \
    --manifest_paths='data/aishell/manifest.train'
then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi

# compute mean and stddev for normalizer
if ! python tools/compute_mean_std.py \
    --manifest_path='data/aishell/manifest.train' \
    --num_samples=2000 \
    --specgram_type='linear' \
    --output_path='data/aishell/mean_std.npz'
then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi

echo "Aishell data preparation done."
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
#
# Train on the Aishell manifests with train.py on GPUs 0-7.

pushd ../.. > /dev/null

# train model
# To resume from an existing model, add an --init_model_path option to the
# list below (no such line is present to uncomment).
if ! CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    python -u train.py \
    --batch_size=64 \
    --trainer_count=8 \
    --num_passes=50 \
    --num_proc_data=12 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=1024 \
    --num_iter_print=100 \
    --learning_rate=5e-4 \
    --max_duration=27.0 \
    --min_duration=0.0 \
    --test_off=False \
    --use_sortagrad=True \
    --use_gru=False \
    --use_gpu=True \
    --is_local=True \
    --share_rnn_weights=False \
    --train_manifest='data/aishell/manifest.train' \
    --dev_manifest='data/aishell/manifest.dev' \
    --mean_std_path='data/aishell/mean_std.npz' \
    --vocab_path='data/aishell/vocab.txt' \
    --output_model_dir='./checkpoints/aishell' \
    --augment_conf_path='conf/augmentation.config' \
    --specgram_type='linear' \
    --shuffle_method='batch_shuffle_clipped'
then
    echo "Failed in training!"
    exit 1
fi

exit 0
|
@ -0,0 +1,45 @@
|
||||
#! /usr/bin/env bash
#
# LibriSpeech data preparation: download and manifests, a shuffled combined
# training manifest, then vocabulary and mean/stddev for the normalizer.

pushd ../.. > /dev/null

# download data, generate manifests
# ($PYPYTHONPATH was a typo that silently dropped the caller's PYTHONPATH.)
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
--manifest_prefix='data/librispeech/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/Libri' \
--full_download='True'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

# merge all training manifests into one, in random order
cat data/librispeech/manifest.train-* | shuf > data/librispeech/manifest.train


# build vocabulary
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/librispeech/vocab.txt' \
--manifest_paths='data/librispeech/manifest.train'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/librispeech/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/librispeech/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "LibriSpeech Data preparation done."
exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
#
# Run infer.py on 10 samples of the LibriSpeech test-clean manifest with the
# checkpoint at checkpoints/libri/params.latest.tar.gz.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# infer
if ! CUDA_VISIBLE_DEVICES=0 \
    python -u infer.py \
    --num_samples=10 \
    --trainer_count=1 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --cutoff_top_n=40 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --infer_manifest='data/librispeech/manifest.test-clean' \
    --mean_std_path='data/librispeech/mean_std.npz' \
    --vocab_path='data/librispeech/vocab.txt' \
    --model_path='checkpoints/libri/params.latest.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in inference!"
    exit 1
fi

exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
#
# Run infer.py on 10 samples of the LibriSpeech test-clean manifest with the
# downloaded model under models/librispeech.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# download well-trained model
pushd models/librispeech > /dev/null
sh download_model.sh || exit 1
popd > /dev/null

# infer
if ! CUDA_VISIBLE_DEVICES=0 \
    python -u infer.py \
    --num_samples=10 \
    --trainer_count=1 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --cutoff_top_n=40 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --infer_manifest='data/librispeech/manifest.test-clean' \
    --mean_std_path='models/librispeech/mean_std.npz' \
    --vocab_path='models/librispeech/vocab.txt' \
    --model_path='models/librispeech/params.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in inference!"
    exit 1
fi

exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
#
# Evaluate the checkpoint at checkpoints/libri/params.latest.tar.gz on the
# LibriSpeech test-clean manifest with test.py on GPUs 0-7.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# evaluate model
if ! CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    python -u test.py \
    --batch_size=128 \
    --trainer_count=8 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_proc_data=4 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --test_manifest='data/librispeech/manifest.test-clean' \
    --mean_std_path='data/librispeech/mean_std.npz' \
    --vocab_path='data/librispeech/vocab.txt' \
    --model_path='checkpoints/libri/params.latest.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in evaluation!"
    exit 1
fi

exit 0
|
@ -0,0 +1,56 @@
|
||||
#! /usr/bin/env bash
#
# Evaluate the downloaded model under models/librispeech on the LibriSpeech
# test-clean manifest with test.py on GPUs 0-7.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# download well-trained model
pushd models/librispeech > /dev/null
sh download_model.sh || exit 1
popd > /dev/null

# evaluate model
if ! CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    python -u test.py \
    --batch_size=128 \
    --trainer_count=8 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_proc_data=4 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --cutoff_top_n=40 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --test_manifest='data/librispeech/manifest.test-clean' \
    --mean_std_path='models/librispeech/mean_std.npz' \
    --vocab_path='models/librispeech/vocab.txt' \
    --model_path='models/librispeech/params.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in evaluation!"
    exit 1
fi

exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
#
# Train on the LibriSpeech manifests with train.py on GPUs 0-7.

pushd ../.. > /dev/null

# train model
# To resume from an existing model, add an --init_model_path option to the
# list below (no such line is present to uncomment).
#
# NOTE(review): --dev_manifest used to be 'data/librispeech/manifest.dev'.
# The data preparation script only assembles manifest.train, and every other
# LibriSpeech manifest referenced by these scripts carries its split suffix
# (manifest.test-clean, manifest.dev-clean), so it now points at
# manifest.dev-clean. Confirm against the manifests librispeech.py writes.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u train.py \
--batch_size=512 \
--trainer_count=8 \
--num_passes=50 \
--num_proc_data=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--test_off=False \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=True \
--share_rnn_weights=True \
--train_manifest='data/librispeech/manifest.train' \
--dev_manifest='data/librispeech/manifest.dev-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--output_model_dir='./checkpoints/libri' \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped'

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi


exit 0
|
@ -0,0 +1,39 @@
|
||||
#! /usr/bin/env bash
#
# Grid-search the language-model hyper-parameters (alpha, beta) with
# tools/tune.py on 100 samples of the LibriSpeech dev-clean manifest.

pushd ../.. > /dev/null

# grid-search for hyper-parameters in language model
if ! CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    python -u tools/tune.py \
    --num_samples=100 \
    --trainer_count=8 \
    --beam_size=500 \
    --num_proc_bsearch=12 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --num_alphas=14 \
    --num_betas=20 \
    --alpha_from=0.1 \
    --alpha_to=0.36 \
    --beta_from=0.05 \
    --beta_to=1.0 \
    --cutoff_prob=0.99 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --tune_manifest='data/librispeech/manifest.dev-clean' \
    --mean_std_path='data/librispeech/mean_std.npz' \
    --vocab_path='data/librispeech/vocab.txt' \
    --model_path='checkpoints/libri/params.latest.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in tuning!"
    exit 1
fi

exit 0
|
@ -0,0 +1,17 @@
|
||||
#! /usr/bin/env bash
#
# Start the demo client against a server on localhost:8086.

pushd ../.. > /dev/null

# start demo client
# The last option no longer ends in a line continuation. The old dangling
# backslash only worked because a blank line followed it.
CUDA_VISIBLE_DEVICES=0 \
python -u deploy/demo_client.py \
--host_ip='localhost' \
--host_port=8086

if [ $? -ne 0 ]; then
    echo "Failed in starting demo client!"
    exit 1
fi


exit 0
|
@ -0,0 +1,53 @@
|
||||
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model
#
# Start the demo server on localhost:8086 with the downloaded model under
# models/librispeech.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# download well-trained model
pushd models/librispeech > /dev/null
sh download_model.sh || exit 1
popd > /dev/null

# start demo server
if ! CUDA_VISIBLE_DEVICES=0 \
    python -u deploy/demo_server.py \
    --host_ip='localhost' \
    --host_port=8086 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=0.36 \
    --beta=0.25 \
    --cutoff_prob=0.99 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --speech_save_dir='demo_cache' \
    --warmup_manifest='data/tiny/manifest.test-clean' \
    --mean_std_path='models/librispeech/mean_std.npz' \
    --vocab_path='models/librispeech/vocab.txt' \
    --model_path='models/librispeech/params.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --specgram_type='linear'
then
    echo "Failed in starting demo server!"
    exit 1
fi

exit 0
|
@ -0,0 +1,51 @@
|
||||
#! /usr/bin/env bash
#
# Tiny data preparation: partial LibriSpeech download, a 64-line manifest cut
# from dev-clean, then vocabulary and mean/stddev for the normalizer.

pushd ../.. > /dev/null

# prepare folder (-p also creates a missing parent and tolerates an existing
# directory)
mkdir -p data/tiny


# download data, generate manifests
# PYTHONPATH prefix added to match the Aishell and LibriSpeech data scripts.
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
--manifest_prefix='data/tiny/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/libri' \
--full_download='False'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

head -n 64 data/tiny/manifest.dev-clean > data/tiny/manifest.tiny


# build vocabulary
# This used to read 'data/tiny/manifest.dev', which nothing in this script
# creates; the dev-clean manifest read by `head` above is used instead.
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/tiny/vocab.txt' \
--manifest_paths='data/tiny/manifest.dev-clean'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/tiny/manifest.tiny' \
--num_samples=64 \
--specgram_type='linear' \
--output_path='data/tiny/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "Tiny data preparation done."
exit 0
|
@ -0,0 +1,45 @@
|
||||
#! /usr/bin/env bash
#
# Run infer.py on 10 samples of the tiny manifest with the checkpoint at
# checkpoints/tiny/params.pass-19.tar.gz.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# infer
if ! CUDA_VISIBLE_DEVICES=0 \
    python -u infer.py \
    --num_samples=10 \
    --trainer_count=1 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --infer_manifest='data/tiny/manifest.tiny' \
    --mean_std_path='data/tiny/mean_std.npz' \
    --vocab_path='data/tiny/vocab.txt' \
    --model_path='checkpoints/tiny/params.pass-19.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in inference!"
    exit 1
fi

exit 0
|
@ -0,0 +1,54 @@
|
||||
#! /usr/bin/env bash
#
# Run infer.py on 10 samples of data/tiny/manifest.test-clean with the
# downloaded model under models/librispeech.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh || exit 1
popd > /dev/null

# download well-trained model
pushd models/librispeech > /dev/null
sh download_model.sh || exit 1
popd > /dev/null

# infer
if ! CUDA_VISIBLE_DEVICES=0 \
    python -u infer.py \
    --num_samples=10 \
    --trainer_count=1 \
    --beam_size=500 \
    --num_proc_bsearch=8 \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --alpha=2.15 \
    --beta=0.35 \
    --cutoff_prob=1.0 \
    --use_gru=False \
    --use_gpu=True \
    --share_rnn_weights=True \
    --infer_manifest='data/tiny/manifest.test-clean' \
    --mean_std_path='models/librispeech/mean_std.npz' \
    --vocab_path='models/librispeech/vocab.txt' \
    --model_path='models/librispeech/params.tar.gz' \
    --lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
    --decoding_method='ctc_beam_search' \
    --error_rate_type='wer' \
    --specgram_type='linear'
then
    echo "Failed in inference!"
    exit 1
fi

exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
#
# Evaluate the tiny checkpoint on the tiny manifest with test.py on GPUs 0-7.

pushd ../.. > /dev/null

# download language model
pushd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
popd > /dev/null


# evaluate model
# NOTE(review): --model_path used to be 'checkpoints/params.pass-19.tar.gz'.
# The tiny inference script reads the same pass-19 checkpoint from
# checkpoints/tiny/, so this now matches it. Confirm against the
# --output_model_dir used by the tiny training script.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=16 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=4 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--model_path='checkpoints/tiny/params.pass-19.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash

# Evaluate the downloaded, pre-trained LibriSpeech model with test.py.
# All paths below are relative to the directory two levels above the current
# working directory, so abort if that directory cannot be entered.
pushd ../.. > /dev/null || exit 1

# download language model
pushd models/lm > /dev/null || exit 1
# The download scripts in this commit declare a bash shebang and use `source`,
# which a POSIX-only `sh` does not provide, so run them with bash.
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
popd > /dev/null


# download well-trained model
pushd models/librispeech > /dev/null || exit 1
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
popd > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=4 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.15 \
--beta=0.35 \
--cutoff_prob=1.0 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/tiny/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash

# Train a model on the tiny dataset with train.py.
# All paths below are relative to the directory two levels above the current
# working directory, so abort if that directory cannot be entered.
pushd ../.. > /dev/null || exit 1

# train model
# if you wish to resume from an existing model, add --init_model_path
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -u train.py \
--batch_size=16 \
--trainer_count=4 \
--num_passes=20 \
--num_proc_data=1 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=1e-5 \
--max_duration=27.0 \
--min_duration=0.0 \
--test_off=False \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=True \
--share_rnn_weights=True \
--train_manifest='data/tiny/manifest.tiny' \
--dev_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--output_model_dir='./checkpoints/tiny' \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped'

# This script runs training, so report a training failure (the previous
# message wrongly mentioned inference).
if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi


exit 0
|
@ -0,0 +1,39 @@
|
||||
#! /usr/bin/env bash

# Grid-search the language-model decoding weights with tools/tune.py.
# All paths below are relative to the directory two levels above the current
# working directory, so abort if that directory cannot be entered.
pushd ../.. > /dev/null || exit 1

# grid-search for hyper-parameters in language model
# NOTE(review): unlike the evaluation scripts, this script does not download
# the language model first, and the tiny training script in this commit writes
# checkpoints to ./checkpoints/tiny -- confirm that both --lang_model_path and
# --model_path exist before running.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u tools/tune.py \
--num_samples=100 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_alphas=14 \
--num_betas=20 \
--alpha_from=0.1 \
--alpha_to=0.36 \
--beta_from=0.05 \
--beta_to=1.0 \
--cutoff_prob=0.99 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--tune_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--model_path='checkpoints/params.pass-9.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in tuning!"
    exit 1
fi


exit 0
|
@ -1,177 +0,0 @@
|
||||
"""Contains DeepSpeech2 layers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.v2 as paddle
|
||||
|
||||
|
||||
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
                  padding, act):
    """Bias-free linear convolution followed by batch normalization.

    :param input: Input layer.
    :type input: LayerOutput
    :param filter_size: The x dimension of a filter kernel. Or input a tuple for
                        two image dimension.
    :type filter_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: Convolution stride, forwarded unchanged to
                   paddle.layer.img_conv (callers in this file pass tuples).
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimension.
    :type padding: int|tuple|list
    :param act: Activation type, applied by the batch norm layer; the
                convolution itself uses a linear activation.
    :type act: BaseActivation
    :return: Batch norm layer after convolution layer.
    :rtype: LayerOutput
    """
    conv_kwargs = dict(
        input=input,
        filter_size=filter_size,
        num_channels=num_channels_in,
        num_filters=num_channels_out,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear(),
        bias_attr=False)
    linear_conv = paddle.layer.img_conv(**conv_kwargs)
    return paddle.layer.batch_norm(input=linear_conv, act=act)
|
||||
|
||||
|
||||
def bidirectional_simple_rnn_bn_layer(name, input, size, act):
    """Bidirectional simple rnn layer with batch normalization on its input.

    One linear projection of the input is batch-normalized and then fed to
    both a forward and a backward recurrent layer, so the input-hidden weights
    are shared by the two directions.

    :param name: Name of the layer (accepted but not forwarded to any paddle
                 layer by this function).
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type of the recurrent layers.
    :type act: BaseActivation
    :return: Concatenation of the forward and backward recurrent layers.
    :rtype: LayerOutput
    """
    # Shared, bias-free input-hidden projection.
    projection = paddle.layer.fc(
        input=input, size=size, act=paddle.activation.Linear(), bias_attr=False)
    # Batch norm is applied to the projection only, not inside the recurrence.
    normalized = paddle.layer.batch_norm(
        input=projection, act=paddle.activation.Linear())
    # Forward direction first, then backward, matching the original order.
    directions = [
        paddle.layer.recurrent(input=normalized, act=act, reverse=is_reversed)
        for is_reversed in (False, True)
    ]
    return paddle.layer.concat(input=directions)
|
||||
|
||||
|
||||
def conv_group(input, num_stacks):
    """Stack of convolution + batch norm layers.

    The first layer is always created; ``num_stacks - 1`` further layers with
    a smaller filter follow it.

    :param input: Input layer.
    :type input: LayerOutput
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    :return: Tuple of (output layer of the group, number of output channels,
             output height).
    :rtype: tuple
    """
    stacked = conv_bn_layer(
        input=input,
        filter_size=(11, 41),
        num_channels_in=1,
        num_channels_out=32,
        stride=(3, 2),
        padding=(5, 20),
        act=paddle.activation.BRelu())
    extra_layers = num_stacks - 1
    for _ in xrange(extra_layers):
        stacked = conv_bn_layer(
            input=stacked,
            filter_size=(11, 21),
            num_channels_in=32,
            num_channels_out=32,
            stride=(1, 2),
            padding=(5, 10),
            act=paddle.activation.BRelu())
    # NOTE(review): the constant 160 presumably corresponds to an input
    # feature height of 161 halved by each stride-2 layer -- confirm.
    height = 160 // (2**num_stacks) + 1
    return stacked, 32, height
|
||||
|
||||
|
||||
def rnn_group(input, size, num_stacks):
    """Stack of bidirectional simple RNN layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :return: Output layer of the RNN group (the input itself when
             num_stacks is 0).
    :rtype: LayerOutput
    """
    stacked = input
    for layer_idx in xrange(num_stacks):
        layer_kwargs = dict(
            name=str(layer_idx),
            input=stacked,
            size=size,
            act=paddle.activation.BRelu())
        stacked = bidirectional_simple_rnn_bn_layer(**layer_kwargs)
    return stacked
|
||||
|
||||
|
||||
def deep_speech2(audio_data,
                 text_data,
                 dict_size,
                 num_conv_layers=2,
                 num_rnn_layers=3,
                 rnn_size=256):
    """
    The whole DeepSpeech2 model structure (a simplified version).

    Pipeline: convolution group -> block_expand into a sequence -> RNN group
    -> fully connected layer of size dict_size + 1 -> softmax output and
    warp-ctc cost.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: LayerOutput
    :param text_data: Transcription text data layer.
    :type text_data: LayerOutput
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (number of RNN cells).
    :type rnn_size: int
    :return: A tuple of the softmax-activated output layer and the ctc cost
             layer (the latter is computed from the pre-softmax fc output).
    :rtype: tuple of LayerOutput
    """
    # One extra class is reserved for the CTC blank, whose index is dict_size.
    num_classes = dict_size + 1
    # convolution group
    conv_out, conv_channels, conv_height = conv_group(
        input=audio_data, num_stacks=num_conv_layers)
    # convert data from convolution feature map to sequence of vectors
    sequence = paddle.layer.block_expand(
        input=conv_out,
        num_channels=conv_channels,
        stride_x=1,
        stride_y=1,
        block_x=1,
        block_y=conv_height)
    # rnn group
    rnn_out = rnn_group(input=sequence, size=rnn_size, num_stacks=num_rnn_layers)
    scores = paddle.layer.fc(
        input=rnn_out,
        size=num_classes,
        act=paddle.activation.Linear(),
        bias_attr=True)
    # probability distribution with softmax
    probs = paddle.layer.mixed(
        input=paddle.layer.identity_projection(input=scores),
        act=paddle.activation.Softmax())
    # ctc cost
    cost = paddle.layer.warp_ctc(
        input=scores,
        label=text_data,
        size=num_classes,
        blank=dict_size,
        norm_by_times=True)
    return probs, cost
|
@ -1,19 +0,0 @@
|
||||
# Download the English language model into ./data and verify its md5 checksum.
echo "Downloading language model ..."

# -p: do not fail when ./data already exists (e.g. when resuming a download).
mkdir -p data

LM=common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"

wget -c "http://paddlepaddle.bj.bcebos.com/model_zoo/speech/$LM" -P ./data

echo "Checking md5sum ..."
md5_tmp=`md5sum "./data/$LM" | awk '{print $1}'`

# Both operands are quoted: with an empty checksum (missing file) the unquoted
# test is a syntax error rather than a mismatch, and the failure went unnoticed.
if [ "$MD5" != "$md5_tmp" ]; then
    echo "Fail to download the language model!"
    exit 1
fi
|
||||
|
||||
|
||||
|
@ -0,0 +1,274 @@
|
||||
"""Contains DeepSpeech2 layers and networks."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.v2 as paddle
|
||||
|
||||
|
||||
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
                  padding, act):
    """Convolution layer (linear, no bias) wrapped in batch normalization.

    :param input: Input layer.
    :type input: LayerOutput
    :param filter_size: The x dimension of a filter kernel. Or input a tuple for
                        two image dimension.
    :type filter_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: Convolution stride, forwarded unchanged to
                   paddle.layer.img_conv (callers in this file pass tuples).
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimension.
    :type padding: int|tuple|list
    :param act: Activation type, applied by the batch norm layer; the
                convolution itself uses a linear activation.
    :type act: BaseActivation
    :return: Batch norm layer after convolution layer.
    :rtype: LayerOutput
    """
    return paddle.layer.batch_norm(
        input=paddle.layer.img_conv(
            input=input,
            filter_size=filter_size,
            num_channels=num_channels_in,
            num_filters=num_channels_out,
            stride=stride,
            padding=padding,
            act=paddle.activation.Linear(),
            bias_attr=False),
        act=act)
|
||||
|
||||
|
||||
def bidirectional_simple_rnn_bn_layer(name, input, size, act, share_weights):
    """Bidirectional simple rnn layer with batch normalization on its input.

    The batch normalization is only performed on the input-state projection,
    never inside the recurrence.

    :param name: Name of the layer (accepted but not forwarded to any paddle
                 layer by this function).
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type of the recurrent layers.
    :type act: BaseActivation
    :param share_weights: Whether to share input-hidden weights between
                          forward and backward directional RNNs.
    :type share_weights: bool
    :return: Concatenation of the forward and backward recurrent layers.
    :rtype: LayerOutput
    """

    def _project():
        # Bias-free linear input-hidden projection.
        return paddle.layer.fc(
            input=input,
            size=size,
            act=paddle.activation.Linear(),
            bias_attr=False)

    def _normalize(projection):
        return paddle.layer.batch_norm(
            input=projection, act=paddle.activation.Linear())

    # Layers are created in the same order as before (all projections, then
    # all batch norms, then the recurrent layers) in case paddle derives
    # default layer names from creation order.
    fwd_proj = _project()
    bwd_proj = fwd_proj if share_weights else _project()
    fwd_bn = _normalize(fwd_proj)
    bwd_bn = fwd_bn if share_weights else _normalize(bwd_proj)
    # forward and backward in time
    fwd_rnn = paddle.layer.recurrent(input=fwd_bn, act=act, reverse=False)
    bwd_rnn = paddle.layer.recurrent(input=bwd_bn, act=act, reverse=True)
    return paddle.layer.concat(input=[fwd_rnn, bwd_rnn])
|
||||
|
||||
|
||||
def bidirectional_gru_bn_layer(name, input, size, act):
    """Bidirectional gru layer with batch normalization on its input.

    Each direction gets its own input projection; the batch normalization is
    only performed on those input-related projections.

    :param name: Name of the layer (accepted but not forwarded to any paddle
                 layer by this function).
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type of the gru layers.
    :type act: BaseActivation
    :return: Concatenation of the forward and backward gru layers.
    :rtype: LayerOutput
    """

    def _project():
        # NOTE(review): the projection is 3x the cell count, presumably because
        # grumemory expects the gate and candidate inputs concatenated --
        # confirm against the paddle grumemory documentation.
        return paddle.layer.fc(
            input=input,
            size=size * 3,
            act=paddle.activation.Linear(),
            bias_attr=False)

    # Separate (unshared) projections, forward first.
    fwd_proj = _project()
    bwd_proj = _project()
    # batch norm is only performed on input-related projections
    fwd_bn, bwd_bn = [
        paddle.layer.batch_norm(
            input=projection, act=paddle.activation.Linear())
        for projection in (fwd_proj, bwd_proj)
    ]
    # forward and backward in time
    fwd_gru = paddle.layer.grumemory(input=fwd_bn, act=act, reverse=False)
    bwd_gru = paddle.layer.grumemory(input=bwd_bn, act=act, reverse=True)
    return paddle.layer.concat(input=[fwd_gru, bwd_gru])
|
||||
|
||||
|
||||
def conv_group(input, num_stacks):
    """Convolution group with stacked convolution + batch norm layers.

    The first layer is always created; ``num_stacks - 1`` further layers with
    a smaller filter follow it.

    :param input: Input layer.
    :type input: LayerOutput
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    :return: Tuple of (output layer of the group, number of output channels,
             output height).
    :rtype: tuple
    """
    num_filters = 32
    layer = conv_bn_layer(
        input=input,
        filter_size=(11, 41),
        num_channels_in=1,
        num_channels_out=32,
        stride=(3, 2),
        padding=(5, 20),
        act=paddle.activation.BRelu())
    for _ in xrange(num_stacks - 1):
        layer = conv_bn_layer(
            input=layer,
            filter_size=(11, 21),
            num_channels_in=32,
            num_channels_out=32,
            stride=(1, 2),
            padding=(5, 10),
            act=paddle.activation.BRelu())
    # NOTE(review): the constant 160 presumably corresponds to an input
    # feature height of 161 halved by each stride-2 layer -- confirm.
    height = 160 // (2**num_stacks) + 1
    return layer, num_filters, height
|
||||
|
||||
|
||||
def rnn_group(input, size, num_stacks, use_gru, share_rnn_weights):
    """RNN group with stacked bidirectional simple RNN or GRU layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: Output layer of the RNN group (the input itself when
             num_stacks is 0).
    :rtype: LayerOutput
    """
    stacked = input
    for layer_idx in xrange(num_stacks):
        layer_name = str(layer_idx)
        if not use_gru:
            stacked = bidirectional_simple_rnn_bn_layer(
                name=layer_name,
                input=stacked,
                size=size,
                act=paddle.activation.BRelu(),
                share_weights=share_rnn_weights)
            continue
        # BRelu does not support hppl, need to add later. Use Relu instead.
        stacked = bidirectional_gru_bn_layer(
            name=layer_name,
            input=stacked,
            size=size,
            act=paddle.activation.Relu())
    return stacked
|
||||
|
||||
|
||||
def deep_speech_v2_network(audio_data,
                           text_data,
                           dict_size,
                           num_conv_layers=2,
                           num_rnn_layers=3,
                           rnn_size=256,
                           use_gru=False,
                           share_rnn_weights=True):
    """The DeepSpeech2 network structure.

    Pipeline: convolution group -> block_expand into a sequence -> RNN group
    -> fully connected layer of size dict_size + 1 -> softmax output and
    warp-ctc cost.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: LayerOutput
    :param text_data: Transcription text data layer.
    :type text_data: LayerOutput
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (number of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: A tuple of the softmax-activated output layer and the ctc cost
             layer (the latter is computed from the pre-softmax fc output).
    :rtype: tuple of LayerOutput
    """
    # One extra class is reserved for the CTC blank, whose index is dict_size.
    num_classes = dict_size + 1
    # convolution group
    conv_out, conv_channels, conv_height = conv_group(
        input=audio_data, num_stacks=num_conv_layers)
    # convert data from convolution feature map to sequence of vectors
    sequence = paddle.layer.block_expand(
        input=conv_out,
        num_channels=conv_channels,
        stride_x=1,
        stride_y=1,
        block_x=1,
        block_y=conv_height)
    # rnn group
    rnn_out = rnn_group(
        input=sequence,
        size=rnn_size,
        num_stacks=num_rnn_layers,
        use_gru=use_gru,
        share_rnn_weights=share_rnn_weights)
    scores = paddle.layer.fc(
        input=rnn_out,
        size=num_classes,
        act=paddle.activation.Linear(),
        bias_attr=True)
    # probability distribution with softmax
    probs = paddle.layer.mixed(
        input=paddle.layer.identity_projection(input=scores),
        act=paddle.activation.Softmax())
    # ctc cost
    cost = paddle.layer.warp_ctc(
        input=scores,
        label=text_data,
        size=num_classes,
        blank=dict_size,
        norm_by_times=True)
    return probs, cost
|
@ -0,0 +1,19 @@
|
||||
#! /usr/bin/env bash

# Download the Aishell model archive and unpack it in the current directory.
# `download` is expected to come from the sourced utility script.
source ../../utils/utility.sh

URL='http://cloud.dlnel.org/filepub/?uuid=6c83b9d8-3255-4adf-9726-0fe0be3d0274'
MD5=28521a58552885a81cf92a1e9b133a71
TARGET=./aishell_model.tar.gz


echo "Download Aishell model ..."
# Quoted: the URL contains `?`, which is a glob character when left unquoted.
download "$URL" "$MD5" "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to download Aishell model!"
    exit 1
fi
tar -zxvf "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to extract Aishell model!"
    exit 1
fi


exit 0
|
@ -0,0 +1,19 @@
|
||||
#! /usr/bin/env bash

# Download the LibriSpeech model archive and unpack it in the current
# directory. `download` is expected to come from the sourced utility script.
source ../../utils/utility.sh

URL='http://cloud.dlnel.org/filepub/?uuid=8e3cf742-2ff3-41ce-a49d-f6158cc06a23'
MD5=2ef08f8b608a7c555592161fc14d81a6
TARGET=./librispeech_model.tar.gz


echo "Download LibriSpeech model ..."
# Quoted: the URL contains `?`, which is a glob character when left unquoted.
download "$URL" "$MD5" "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to download LibriSpeech model!"
    exit 1
fi
tar -zxvf "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to extract LibriSpeech model!"
    exit 1
fi


exit 0
|
@ -0,0 +1,18 @@
|
||||
#! /usr/bin/env bash

# Download the zh_giga language model into the current directory.
# `download` is expected to come from the sourced utility script.
source ../../utils/utility.sh

URL='http://cloud.dlnel.org/filepub/?uuid=d21861e4-4ed6-45bb-ad8e-ae417a43195e'
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=./zh_giga.no_cna_cmn.prune01244.klm


echo "Download language model ..."
# Quoted: the URL contains `?`, which is a glob character when left unquoted.
download "$URL" "$MD5" "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to download the language model!"
    exit 1
fi


exit 0
|
@ -0,0 +1,18 @@
|
||||
#! /usr/bin/env bash

# Download the English common-crawl language model into the current directory.
# `download` is expected to come from the sourced utility script.
source ../../utils/utility.sh

URL='http://paddlepaddle.bj.bcebos.com/model_zoo/speech/common_crawl_00.prune01111.trie.klm'
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=./common_crawl_00.prune01111.trie.klm


echo "Download language model ..."
# Arguments are quoted so they are passed through without word splitting or
# pathname expansion, consistent with the other download scripts.
download "$URL" "$MD5" "$TARGET"
if [ $? -ne 0 ]; then
    echo "Fail to download the language model!"
    exit 1
fi


exit 0
|
@ -0,0 +1,127 @@
|
||||
"""Evaluation for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model_utils.model import DeepSpeech2Model
|
||||
from utils.error_rate import wer, cer
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
# Command-line options of the evaluation script. Each add_arg call registers
# one option (name, type, default, help) on the shared parser.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size',       int,    128,    "Minibatch size.")
add_arg('trainer_count',    int,    8,      "# of Trainers (CPUs or GPUs).")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_proc_bsearch', int,    12,     "# of CPUs for beam search.")
add_arg('num_proc_data',    int,    12,     "# of CPUs for data preprocessing.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  2.15,   "Coef of LM for beam search.")
add_arg('beta',             float,  0.35,   "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('test_manifest',    str,
        'data/librispeech/manifest.test-clean',
        "Filepath of manifest to evaluate.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
# The help text used to describe resuming training, which this evaluation
# script never does; the path is passed as the pre-trained model to evaluate.
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "Filepath of the pre-trained model parameters to evaluate.")
add_arg('lang_model_path',  str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type',  str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()
|
||||
|
||||
|
||||
def evaluate():
    """Evaluate on whole test data for DeepSpeech2.

    Reads the test manifest batch by batch, decodes every batch with the
    configured decoding method and prints the running and the final average
    error rate (WER or CER, per ``args.error_rate_type``). All settings come
    from the module-level ``args``.
    """
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_proc_data)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.test_manifest,
        batch_size=args.batch_size,
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # decoders only accept string encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    error_sum, num_ins = 0.0, 0
    for infer_data in batch_reader():
        result_transcripts = ds2_model.infer_batch(
            infer_data=infer_data,
            decoding_method=args.decoding_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            cutoff_top_n=args.cutoff_top_n,
            vocab_list=vocab_list,
            language_model_path=args.lang_model_path,
            num_processes=args.num_proc_bsearch)
        # Rebuild the reference text from the token ids of each utterance.
        target_transcripts = [
            ''.join([data_generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]
        for target, result in zip(target_transcripts, result_transcripts):
            error_sum += error_rate_func(target, result)
            num_ins += 1
        print("Error rate [%s] (%d/?) = %f" %
              (args.error_rate_type, num_ins, error_sum / num_ins))
    if num_ins == 0:
        # An empty manifest would otherwise raise ZeroDivisionError below.
        print("No utterance was evaluated; check the test manifest: %s" %
              args.test_manifest)
    else:
        print("Final error rate [%s] (%d/%d) = %f" %
              (args.error_rate_type, num_ins, num_ins, error_sum / num_ins))

    ds2_model.logger.info("finish evaluation")
|
||||
|
||||
def main():
    """Print the parsed arguments, initialize paddle and run the evaluation."""
    print_arguments(args)
    init_options = {
        'use_gpu': args.use_gpu,
        'trainer_count': args.trainer_count,
    }
    paddle.init(**init_options)
    evaluate()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,59 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Test error rate."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import error_rate
|
||||
|
||||
|
||||
class TestParse(unittest.TestCase):
    """Unit tests for error_rate.wer and error_rate.cer."""

    def test_wer_1(self):
        ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night'
        hyp = 'i GOT IT TO the FULLEST i LOVE TO portable FROM OF STORES last night'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.769230769231, delta=1e-6)

    def test_wer_2(self):
        # Identical reference and hypothesis give a zero error rate.
        ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night'
        word_error_rate = error_rate.wer(ref, ref)
        self.assertEqual(word_error_rate, 0.0)

    def test_wer_3(self):
        # A reference without any word is rejected.
        ref = ' '
        hyp = 'Hypothesis sentence'
        with self.assertRaises(ValueError):
            error_rate.wer(ref, hyp)

    def test_cer_1(self):
        ref = 'werewolf'
        hyp = 'weae wolf'
        char_error_rate = error_rate.cer(ref, hyp)
        self.assertAlmostEqual(char_error_rate, 0.25, delta=1e-6)

    def test_cer_2(self):
        ref = 'werewolf'
        char_error_rate = error_rate.cer(ref, ref)
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_3(self):
        ref = u'我是中国人'
        hyp = u'我是 美洲人'
        char_error_rate = error_rate.cer(ref, hyp)
        self.assertAlmostEqual(char_error_rate, 0.6, delta=1e-6)

    def test_cer_4(self):
        ref = u'我是中国人'
        char_error_rate = error_rate.cer(ref, ref)
        # assertFalse(value, 0.0) treated 0.0 as the failure message instead
        # of comparing against it; compare explicitly.
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_5(self):
        # An empty reference is rejected.
        ref = ''
        hyp = 'Hypothesis'
        with self.assertRaises(ValueError):
            error_rate.cer(ref, hyp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1,23 +0,0 @@
|
||||
"""Test Setup."""
|
||||
import unittest
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class TestSetup(unittest.TestCase):
    """Sanity checks for the installed third-party dependencies."""

    def test_soundfile(self):
        """Write a float WAV file with soundfile and read it back unchanged."""
        import soundfile as sf
        # floating point data is typically limited to the interval [-1.0, 1.0],
        # but smaller/larger values are supported as well
        data = np.array([[1.75, -1.75], [1.0, -1.0], [0.5, -0.5],
                         [0.25, -0.25]])
        wav_path = 'test.wav'
        try:
            sf.write(wav_path, data, 44100, format='WAV', subtype='FLOAT')
            read, fs = sf.read(wav_path)
            self.assertTrue(np.all(read == data))
            self.assertEqual(fs, 44100)
        finally:
            # Remove the file even when an assertion above fails.
            if os.path.exists(wav_path):
                os.remove(wav_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,19 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Put ``path`` at the front of ``sys.path`` unless it is already listed.

    :param path: Directory to make importable.
    :type path: str
    """
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
||||
|
||||
|
||||
# Resolve to an absolute directory so the sys.path entry added below keeps
# working even if the current working directory changes later.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Add project path (the parent of this directory) to the module search path
proj_path = os.path.normpath(os.path.join(this_dir, '..'))
add_path(proj_path)
|
@ -0,0 +1,58 @@
|
||||
"""Build vocabulary from manifest files.
|
||||
|
||||
Each item in vocabulary file is a character.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import codecs
|
||||
import json
|
||||
from collections import Counter
|
||||
import os.path
|
||||
import _init_paths
|
||||
from data_utils.utility import read_manifest
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
# yapf: disable
|
||||
add_arg('count_threshold', int, 0, "Truncation threshold for char counts.")
|
||||
add_arg('vocab_path', str,
|
||||
'data/librispeech/vocab.txt',
|
||||
"Filepath to write the vocabulary.")
|
||||
add_arg('manifest_paths', str,
|
||||
None,
|
||||
"Filepaths of manifests for building vocabulary. "
|
||||
"You can provide multiple manifest files.",
|
||||
nargs='+',
|
||||
required=True)
|
||||
# yapf: enable
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def count_manifest(counter, manifest_path):
    """Add the character counts of every transcript in a manifest.

    :param counter: Counter updated in place with per-character counts.
    :param manifest_path: Path handed to ``read_manifest``.
    """
    for entry in read_manifest(manifest_path):
        # Counter.update() over a string tallies each character once,
        # which matches updating with the characters one at a time.
        counter.update(entry['text'])
|
||||
|
||||
|
||||
def main():
    """Count characters over all manifests and write the vocabulary file."""
    print_arguments(args)

    char_counts = Counter()
    for path in args.manifest_paths:
        count_manifest(char_counts, path)

    # most_common() yields (char, count) pairs by descending count, so the
    # loop can stop at the first character below the threshold.
    with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
        for char, count in char_counts.most_common():
            if count < args.count_threshold:
                break
            fout.write(char + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,51 @@
|
||||
"""Compute mean and std for feature normalizer, and save to file."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import _init_paths
|
||||
from data_utils.normalizer import FeatureNormalizer
|
||||
from data_utils.augmentor.augmentation import AugmentationPipeline
|
||||
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
# Command-line interface; parsed at import time into the module-level `args`.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples',      int,    2000,    "# of samples to use for statistics.")
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
add_arg('manifest_path',    str,
        'data/librispeech/manifest.train',
        "Filepath of manifest to compute normalizer's mean and stddev.")
add_arg('output_path',      str,
        'data/librispeech/mean_std.npz',
        "Filepath to write mean and stddev to (.npz).")
# yapf: enable
args = parser.parse_args()
|
||||
|
||||
|
||||
def main():
    """Build a FeatureNormalizer from the manifest and save its statistics."""
    print_arguments(args)

    pipeline = AugmentationPipeline('{}')
    featurizer = AudioFeaturizer(specgram_type=args.specgram_type)

    def augment_and_featurize(audio_segment):
        # Run the augmentation pipeline first, then extract the features.
        pipeline.transform_audio(audio_segment)
        return featurizer.featurize(audio_segment)

    FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=args.manifest_path,
        featurize_func=augment_and_featurize,
        num_samples=args.num_samples).write_to_file(args.output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,30 @@
|
||||
#! /usr/bin/env bash
#
# Run one training pass of train.py for several GPU counts and print the
# third field of every log line containing "Time" for each run.

BATCH_SIZE_PER_GPU=64
MIN_DURATION=6.0
MAX_DURATION=7.0

# join_by SEP ITEM...: print the ITEMs joined by the single character SEP.
function join_by { local IFS="$1"; shift; echo "$*"; }

for NUM_GPUS in 16 8 4 2 1
do
    DEVICES=$(join_by , $(seq 0 $((NUM_GPUS - 1))))
    BATCH_SIZE=$((BATCH_SIZE_PER_GPU * NUM_GPUS))

    if ! CUDA_VISIBLE_DEVICES=$DEVICES \
        python train.py \
        --batch_size="$BATCH_SIZE" \
        --num_passes=1 \
        --test_off=True \
        --trainer_count="$NUM_GPUS" \
        --min_duration="$MIN_DURATION" \
        --max_duration="$MAX_DURATION" > tmp.log 2>&1
    then
        # Say why the script stops; the captured output stays in tmp.log.
        echo "train.py failed with ${NUM_GPUS} GPU(s); see tmp.log" >&2
        exit 1
    fi

    # Pass the GPU count with -v instead of splicing it into the awk program.
    grep "Time" tmp.log | awk -v n="$NUM_GPUS" '{print "GPU Num: " n " Time: "$3}'

    rm tmp.log
done
|
@ -0,0 +1,131 @@
|
||||
"""Beam search parameters tuning for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import functools
|
||||
import paddle.v2 as paddle
|
||||
import _init_paths
|
||||
from data_utils.data import DataGenerator
|
||||
from model_utils.model import DeepSpeech2Model
|
||||
from utils.error_rate import wer
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
# yapf: disable
|
||||
add_arg('num_samples', int, 100, "# of samples to infer.")
|
||||
add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).")
|
||||
add_arg('beam_size', int, 500, "Beam search width.")
|
||||
add_arg('num_proc_bsearch', int, 12, "# of CPUs for beam search.")
|
||||
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
|
||||
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
|
||||
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
|
||||
add_arg('num_alphas', int, 14, "# of alpha candidates for tuning.")
|
||||
add_arg('num_betas', int, 20, "# of beta candidates for tuning.")
|
||||
add_arg('alpha_from', float, 0.1, "Where alpha starts tuning from.")
|
||||
add_arg('alpha_to', float, 0.36, "Where alpha ends tuning with.")
|
||||
add_arg('beta_from', float, 0.05, "Where beta starts tuning from.")
|
||||
add_arg('beta_to', float, 1.0, "Where beta ends tuning with.")
|
||||
add_arg('cutoff_prob', float, 0.99, "Cutoff probability for pruning.")
|
||||
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
|
||||
add_arg('use_gpu', bool, True, "Use GPU or not.")
|
||||
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
|
||||
"bi-directional RNNs. Not for GRU.")
|
||||
add_arg('tune_manifest', str,
|
||||
'data/librispeech/manifest.dev',
|
||||
"Filepath of manifest to tune.")
|
||||
add_arg('mean_std_path', str,
|
||||
'data/librispeech/mean_std.npz',
|
||||
"Filepath of normalizer's mean & std.")
|
||||
add_arg('vocab_path', str,
|
||||
'data/librispeech/vocab.txt',
|
||||
"Filepath of vocabulary.")
|
||||
add_arg('lang_model_path', str,
|
||||
'models/lm/common_crawl_00.prune01111.trie.klm',
|
||||
"Filepath for language model.")
|
||||
add_arg('model_path', str,
|
||||
'./checkpoints/libri/params.latest.tar.gz',
|
||||
"If None, the training starts from scratch, "
|
||||
"otherwise, it resumes from the pre-trained model.")
|
||||
add_arg('error_rate_type', str,
|
||||
'wer',
|
||||
"Error rate type for evaluation.",
|
||||
choices=['wer', 'cer'])
|
||||
add_arg('specgram_type', str,
|
||||
'linear',
|
||||
"Audio feature type. Options: linear, mfcc.",
|
||||
choices=['linear', 'mfcc'])
|
||||
# yapf: enable
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def tune():
    """Tune parameters alpha and beta on one minibatch.

    Takes the first batch of ``args.num_samples`` items from
    ``args.tune_manifest``, decodes it with ``ctc_beam_search`` for every
    (alpha, beta) pair on the grid, and prints the mean WER of each pair.

    :raises ValueError: If ``num_alphas`` or ``num_betas`` is negative.
    """
    if args.num_alphas < 0:
        raise ValueError("num_alphas must be non-negative!")
    if args.num_betas < 0:
        raise ValueError("num_betas must be non-negative!")

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.tune_manifest,
        batch_size=args.num_samples,
        sortagrad=False,
        shuffle_method=None)
    # The next() builtin works on Python 2 and 3; the .next() method is
    # Python 2-only.
    tune_data = next(batch_reader())
    target_transcripts = [
        ''.join([data_generator.vocab_list[token] for token in transcript])
        for _, transcript in tune_data
    ]

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # create grid for search
    cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in cand_alphas
                   for beta in cand_betas]

    # NOTE(review): args.error_rate_type is parsed but not used here; only
    # WER is computed. Confirm whether CER support was intended.
    ## tune parameters in loop
    for alpha, beta in params_grid:
        result_transcripts = ds2_model.infer_batch(
            infer_data=tune_data,
            decoding_method='ctc_beam_search',
            beam_alpha=alpha,
            beam_beta=beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            vocab_list=data_generator.vocab_list,
            language_model_path=args.lang_model_path,
            num_processes=args.num_proc_bsearch)
        wer_sum, num_ins = 0.0, 0
        for target, result in zip(target_transcripts, result_transcripts):
            wer_sum += wer(target, result)
            num_ins += 1
        print("alpha = %f\tbeta = %f\tWER = %f" %
              (alpha, beta, wer_sum / num_ins))
|
||||
|
||||
|
||||
def main():
    """Print the configuration, initialize paddle, then run the tuning."""
    print_arguments(args)
    paddle.init(trainer_count=args.trainer_count, use_gpu=args.use_gpu)
    tune()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,216 +0,0 @@
|
||||
"""Parameters tuning for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import distutils.util
|
||||
import argparse
|
||||
import multiprocessing
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model import DeepSpeech2Model
|
||||
from error_rate import wer
|
||||
import utils
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
default=128,
|
||||
type=int,
|
||||
help="Minibatch size for parameters tuning. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_conv_layers",
|
||||
default=2,
|
||||
type=int,
|
||||
help="Convolution layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_rnn_layers",
|
||||
default=3,
|
||||
type=int,
|
||||
help="RNN layer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--rnn_layer_size",
|
||||
default=512,
|
||||
type=int,
|
||||
help="RNN layer cell number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--use_gpu",
|
||||
default=True,
|
||||
type=distutils.util.strtobool,
|
||||
help="Use gpu or not. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--trainer_count",
|
||||
default=8,
|
||||
type=int,
|
||||
help="Trainer number. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_threads_data",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Number of cpu threads for preprocessing data. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--num_processes_beam_search",
|
||||
default=multiprocessing.cpu_count(),
|
||||
type=int,
|
||||
help="Number of cpu processes for beam search. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--specgram_type",
|
||||
default='linear',
|
||||
type=str,
|
||||
help="Feature type of audio data: 'linear' (power spectrum)"
|
||||
" or 'mfcc'. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--mean_std_filepath",
|
||||
default='mean_std.npz',
|
||||
type=str,
|
||||
help="Manifest path for normalizer. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--tune_manifest_path",
|
||||
default='datasets/manifest.dev',
|
||||
type=str,
|
||||
help="Manifest path for tuning. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--model_filepath",
|
||||
default='checkpoints/params.latest.tar.gz',
|
||||
type=str,
|
||||
help="Model filepath. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--vocab_filepath",
|
||||
default='datasets/vocab/eng_vocab.txt',
|
||||
type=str,
|
||||
help="Vocabulary filepath. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--beam_size",
|
||||
default=500,
|
||||
type=int,
|
||||
help="Width for beam search decoding. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--language_model_path",
|
||||
default="lm/data/common_crawl_00.prune01111.trie.klm",
|
||||
type=str,
|
||||
help="Path for language model. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--alpha_from",
|
||||
default=0.1,
|
||||
type=float,
|
||||
help="Where alpha starts from. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--num_alphas",
|
||||
default=14,
|
||||
type=int,
|
||||
help="Number of candidate alphas. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--alpha_to",
|
||||
default=0.36,
|
||||
type=float,
|
||||
help="Where alpha ends with. (default: %(default)f)")
|
||||
parser.add_argument(
|
||||
"--beta_from",
|
||||
default=0.05,
|
||||
type=float,
|
||||
help="Where beta starts from. (default: %(default)f)")
|
||||
# The count is passed to np.linspace as `num`, so it must parse as an int.
parser.add_argument(
    "--num_betas",
    default=20,
    type=int,
    help="Number of candidate betas. (default: %(default)d)")
|
||||
parser.add_argument(
|
||||
"--beta_to",
|
||||
default=1.0,
|
||||
type=float,
|
||||
help="Where beta ends with. (default: %(default)f)")
|
||||
parser.add_argument(
    "--cutoff_prob",
    default=0.99,
    type=float,
    # Trailing space keeps the two concatenated literals from running together.
    help="The cutoff probability of pruning "
    "in beam search. (default: %(default)f)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def tune():
    """Grid-search alpha and beta for the beam search decoder, batch by batch.

    After every minibatch the running average WER of each (alpha, beta) pair
    is printed, followed by the best pair seen so far, so the run can be
    stopped once the optimum stops moving.
    """
    if not args.num_alphas >= 0:
        raise ValueError("num_alphas must be non-negative!")
    if not args.num_betas >= 0:
        raise ValueError("num_betas must be non-negative!")

    generator = DataGenerator(
        vocab_filepath=args.vocab_filepath,
        mean_std_filepath=args.mean_std_filepath,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=args.num_threads_data)
    reader = generator.batch_reader_creator(
        manifest_path=args.tune_manifest_path,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

    model = DeepSpeech2Model(
        vocab_size=generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        pretrained_model_path=args.model_filepath)

    # create grid for search
    alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
    betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
    params_grid = [(alpha, beta) for alpha in alphas for beta in betas]

    grid_size = len(params_grid)
    wer_sum = [0.0] * grid_size
    ave_wer = [0.0] * grid_size
    num_ins = 0
    ## incremental tuning parameters over multiple batches
    for num_batches, infer_data in enumerate(reader()):
        target_transcripts = [
            ''.join([generator.vocab_list[token] for token in transcript])
            for _, transcript in infer_data
        ]
        num_ins += len(target_transcripts)

        # grid search
        for index, (alpha, beta) in enumerate(params_grid):
            result_transcripts = model.infer_batch(
                infer_data=infer_data,
                decode_method='beam_search',
                beam_alpha=alpha,
                beam_beta=beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                vocab_list=generator.vocab_list,
                language_model_path=args.language_model_path,
                num_processes=args.num_processes_beam_search)

            for target, result in zip(target_transcripts, result_transcripts):
                wer_sum[index] += wer(target, result)
            ave_wer[index] = wer_sum[index] / num_ins
            print("alpha = %f, beta = %f, WER = %f" %
                  (alpha, beta, ave_wer[index]))

        # report the best grid point once the current batch is finished
        ave_wer_min = min(ave_wer)
        best_alpha, best_beta = params_grid[ave_wer.index(ave_wer_min)]
        print("Finish batch %d, optimal (alpha, beta, WER) = (%f, %f, %f)\n" %
              (num_batches, best_alpha, best_beta, ave_wer_min))
|
||||
|
||||
|
||||
def main():
    """Print the configuration, initialize paddle, then run the tuning."""
    utils.print_arguments(args)
    paddle.init(trainer_count=args.trainer_count, use_gpu=args.use_gpu)
    tune()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,25 +0,0 @@
|
||||
"""Contains common utility functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def print_arguments(args):
    """Print argparse's arguments.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----- Configuration Arguments -----")
    # items() exists on Python 2 and 3; iteritems() is Python 2-only.
    for arg, value in vars(args).items():
        print("%s: %s" % (arg, value))
    print("------------------------------------")
|
@ -0,0 +1,115 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Test error rate."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
from utils import error_rate
|
||||
|
||||
|
||||
class TestParse(unittest.TestCase):
    """Unit tests for error_rate.wer and error_rate.cer."""

    def test_wer_1(self):
        ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night'
        hyp = 'i GOT IT TO the FULLEST i LOVE TO portable FROM OF STORES last '\
              'night'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.769230769231, delta=1e-6)

    def test_wer_2(self):
        ref = 'as any in england i would say said gamewell proudly that is '\
              'in his day'
        hyp = 'as any in england i would say said came well proudly that is '\
              'in his day'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.1333333, delta=1e-6)

    def test_wer_3(self):
        ref = 'the lieutenant governor lilburn w boggs afterward governor '\
              'was a pronounced mormon hater and throughout the period of '\
              'the troubles he manifested sympathy with the persecutors'
        hyp = 'the lieutenant governor little bit how bags afterward '\
              'governor was a pronounced warman hater and throughout the '\
              'period of th troubles he manifests sympathy with the '\
              'persecutors'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.2692307692, delta=1e-6)

    def test_wer_4(self):
        ref = 'the wood flamed up splendidly under the large brewing copper '\
              'and it sighed so deeply'
        hyp = 'the wood flame do splendidly under the large brewing copper '\
              'and its side so deeply'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.2666666667, delta=1e-6)

    def test_wer_5(self):
        ref = 'all the morning they trudged up the mountain path and at noon '\
              'unc and ojo sat on a fallen tree trunk and ate the last of '\
              'the bread which the old munchkin had placed in his pocket'
        hyp = 'all the morning they trudged up the mountain path and at noon '\
              'unc in ojo sat on a fallen tree trunk and ate the last of '\
              'the bread which the old munchkin had placed in his pocket'
        word_error_rate = error_rate.wer(ref, hyp)
        self.assertAlmostEqual(word_error_rate, 0.027027027, delta=1e-6)

    def test_wer_6(self):
        ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night'
        word_error_rate = error_rate.wer(ref, ref)
        self.assertEqual(word_error_rate, 0.0)

    def test_wer_7(self):
        ref = ' '
        hyp = 'Hypothesis sentence'
        with self.assertRaises(ValueError):
            error_rate.wer(ref, hyp)

    def test_cer_1(self):
        ref = 'werewolf'
        hyp = 'weae wolf'
        char_error_rate = error_rate.cer(ref, hyp)
        self.assertAlmostEqual(char_error_rate, 0.25, delta=1e-6)

    def test_cer_2(self):
        ref = 'werewolf'
        hyp = 'weae wolf'
        char_error_rate = error_rate.cer(ref, hyp, remove_space=True)
        self.assertAlmostEqual(char_error_rate, 0.125, delta=1e-6)

    def test_cer_3(self):
        ref = 'were wolf'
        hyp = 'were wolf'
        char_error_rate = error_rate.cer(ref, hyp)
        self.assertAlmostEqual(char_error_rate, 0.0, delta=1e-6)

    def test_cer_4(self):
        ref = 'werewolf'
        char_error_rate = error_rate.cer(ref, ref)
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_5(self):
        ref = u'我是中国人'
        hyp = u'我是 美洲人'
        char_error_rate = error_rate.cer(ref, hyp)
        self.assertAlmostEqual(char_error_rate, 0.6, delta=1e-6)

    def test_cer_6(self):
        ref = u'我 是 中 国 人'
        hyp = u'我 是 美 洲 人'
        char_error_rate = error_rate.cer(ref, hyp, remove_space=True)
        self.assertAlmostEqual(char_error_rate, 0.4, delta=1e-6)

    def test_cer_7(self):
        ref = u'我是中国人'
        char_error_rate = error_rate.cer(ref, ref)
        # assertFalse(value, 0.0) would take 0.0 as the failure message and
        # only check falsiness; assert the expected value directly.
        self.assertEqual(char_error_rate, 0.0)

    def test_cer_8(self):
        ref = ''
        hyp = 'Hypothesis'
        with self.assertRaises(ValueError):
            error_rate.cer(ref, hyp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,47 @@
|
||||
"""Contains common utility functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
|
||||
|
||||
def print_arguments(args):
    """Print argparse's arguments, sorted by argument name.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----------- Configuration Arguments -----------")
    # items() exists on Python 2 and 3; iteritems() is Python 2-only.
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")
|
||||
|
||||
|
||||
def _strtobool(value):
    """Convert a truth-value string to 1 or 0.

    Accepts y/yes/t/true/on/1 and n/no/f/false/off/0 (case-insensitive),
    the same values ``distutils.util.strtobool`` accepts.

    :raises ValueError: If ``value`` is not one of the recognized strings.
    """
    lowered = value.lower()
    if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
        return 1
    if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
        return 0
    raise ValueError("invalid truth value %r" % (value,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Add argparse's argument.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()

    :param argname: Option name without the leading ``--``.
    :param type: Converter for the value; ``bool`` is replaced by a
                 truth-string parser, since ``bool("False")`` is True.
    :param default: Value used when the option is not given.
    :param help: Help text; the default value is appended to it.
    :param argparser: Parser the option is registered on.
    :param kwargs: Extra keyword arguments forwarded to ``add_argument``.
    """
    # Parsed locally because distutils.util.strtobool is unavailable on
    # Python 3.12+.
    type = _strtobool if type is bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)
|
@ -0,0 +1,23 @@
|
||||
# download URL MD5 TARGET
#
# Fetch URL into TARGET with wget, skipping the transfer when TARGET already
# exists and its md5sum equals MD5. Returns 0 on success (or skip), 1 when
# wget fails or the downloaded file's checksum does not match.
download() {
    URL=$1
    MD5=$2
    TARGET=$3

    # Expansions are quoted so paths with spaces and empty arguments do not
    # break the tests or get word-split.
    if [ -e "$TARGET" ]; then
        md5_result=$(md5sum "$TARGET" | awk '{print $1}')
        if [ "$MD5" = "$md5_result" ]; then
            echo "$TARGET already exists, download skipped."
            return 0
        fi
    fi

    wget -c "$URL" -O "$TARGET" || return 1

    md5_result=$(md5sum "$TARGET" | awk '{print $1}')
    if [ "$MD5" != "$md5_result" ]; then
        return 1
    fi
}
|
Loading…
Reference in new issue