commit
d2bdd254a3
@ -0,0 +1,29 @@
|
||||
# This file is used by clang-format to autoformat paddle source code
|
||||
#
|
||||
# The clang-format is part of llvm toolchain.
|
||||
# It need to install llvm and clang to format source code style.
|
||||
#
|
||||
# The basic usage is,
|
||||
# clang-format -i -style=file PATH/TO/SOURCE/CODE
|
||||
#
|
||||
# The -style=file implicit use ".clang-format" file located in one of
|
||||
# parent directory.
|
||||
# The -i means inplace change.
|
||||
#
|
||||
# The document of clang-format is
|
||||
# http://clang.llvm.org/docs/ClangFormat.html
|
||||
# http://clang.llvm.org/docs/ClangFormatStyleOptions.html
|
||||
---
|
||||
Language: Cpp
|
||||
BasedOnStyle: Google
|
||||
IndentWidth: 2
|
||||
TabWidth: 2
|
||||
ContinuationIndentWidth: 4
|
||||
MaxEmptyLinesToKeep: 2
|
||||
AccessModifierOffset: -2 # The private/protected/public has no indent in class
|
||||
Standard: Cpp11
|
||||
AllowAllParametersOfDeclarationOnNextLine: true
|
||||
BinPackParameters: false
|
||||
BinPackArguments: false
|
||||
...
|
||||
|
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
# Pre-commit hook wrapper: verify the installed clang-format version,
# then delegate all arguments to it.
set -e

readonly VERSION="3.9"

version=$(clang-format -version)

if ! [[ $version == *"$VERSION"* ]]; then
    echo "clang-format version check failed."
    echo "a version contains '$VERSION' is needed, but get '$version'"
    echo "you can install the right version, and make an soft-link to '\$PATH' env"
    # 'exit -1' is not portable: POSIX exit statuses are 0-255.
    exit 1
fi

# Quote "$@" so filenames containing spaces are passed through intact.
clang-format "$@"
|
@ -0,0 +1,2 @@
|
||||
.DS_Store
|
||||
*.pyc
|
@ -0,0 +1,35 @@
|
||||
- repo: https://github.com/pre-commit/mirrors-yapf.git
|
||||
sha: v0.16.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
files: \.py$
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
sha: a11d9314b22d8f8c7556443875b731ef05965464
|
||||
hooks:
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: detect-private-key
|
||||
files: (?!.*paddle)^.*$
|
||||
- id: end-of-file-fixer
|
||||
files: \.md$
|
||||
- id: trailing-whitespace
|
||||
files: \.md$
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks
|
||||
sha: v1.0.1
|
||||
hooks:
|
||||
- id: forbid-crlf
|
||||
files: \.md$
|
||||
- id: remove-crlf
|
||||
files: \.md$
|
||||
- id: forbid-tabs
|
||||
files: \.md$
|
||||
- id: remove-tabs
|
||||
files: \.md$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
description: Format files with ClangFormat
|
||||
entry: bash .clang_format.hook -i
|
||||
language: system
|
||||
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
|
@ -0,0 +1,3 @@
|
||||
[style]
|
||||
based_on_style = pep8
|
||||
column_limit = 80
|
@ -0,0 +1,34 @@
|
||||
language: cpp
|
||||
cache: ccache
|
||||
sudo: required
|
||||
dist: trusty
|
||||
services:
|
||||
- docker
|
||||
os:
|
||||
- linux
|
||||
env:
|
||||
- JOB=PRE_COMMIT
|
||||
|
||||
addons:
|
||||
apt:
|
||||
packages:
|
||||
- git
|
||||
- python
|
||||
- python-pip
|
||||
- python2.7-dev
|
||||
|
||||
before_install:
|
||||
- sudo pip install -U virtualenv pre-commit pip
|
||||
- docker pull paddlepaddle/paddle:latest
|
||||
|
||||
script:
|
||||
- exit_code=0
|
||||
- .travis/precommit.sh || exit_code=$(( exit_code | $? ))
|
||||
- docker run -i --rm -v "$PWD:/py_unittest" paddlepaddle/paddle:latest /bin/bash -c
|
||||
'cd /py_unittest; sh .travis/unittest.sh' || exit_code=$(( exit_code | $? ))
|
||||
- exit $exit_code
|
||||
|
||||
notifications:
|
||||
email:
|
||||
on_success: change
|
||||
on_failure: always
|
@ -0,0 +1,21 @@
|
||||
#!/bin/bash
# Run all pre-commit hooks from the repository root; abort with a code
# style message if any hook fails or leaves the working tree dirty.

abort() {
    echo "Your commit not fit PaddlePaddle code style" 1>&2
    echo "Please use pre-commit scripts to auto-format your code" 1>&2
    exit 1
}

trap 'abort' 0
set -e

# Move from .travis/ to the repository root.
cd "$(dirname "$0")/.."
export PATH=/usr/bin:$PATH
pre-commit install

if ! pre-commit run -a ; then
    ls -lh
    git diff --exit-code
    exit 1
fi

trap : 0
|
@ -0,0 +1,29 @@
|
||||
#!/bin/bash
# Discover and run every python unittest suite found under 'tests'
# directories below the given path.

abort() {
    echo "Run unittest failed" 1>&2
    echo "Please check your code" 1>&2
    exit 1
}

unittest() {
    cd "$1" > /dev/null
    # Optional per-project native setup (e.g. building extensions).
    if [ -f "setup.sh" ]; then
        sh setup.sh
        export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
    fi
    if [ $? != 0 ]; then
        exit 1
    fi
    # -print0/-0 keep paths with whitespace intact.
    find . -name 'tests' -type d -print0 |
        xargs -0 -I{} -n1 bash -c 'python -m unittest discover -v -s {}'
    cd - > /dev/null
}

trap 'abort' 0
set -e

unittest .

trap : 0
|
@ -0,0 +1,17 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Put *path* at the front of ``sys.path`` unless already present."""
    if path in sys.path:
        return
    sys.path.insert(0, path)


# Make the project root (one level above this file) importable.
this_dir = os.path.dirname(__file__)
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
|
@ -0,0 +1,29 @@
|
||||
#! /usr/bin/env bash
# Submit a DeepSpeech2 training job to paddlecloud. The train script is
# temporarily copied into the DS2 source dir so it ships with the job.

TRAIN_MANIFEST="cloud/cloud_manifests/cloud.manifest.train"
DEV_MANIFEST="cloud/cloud_manifests/cloud.manifest.dev"
CLOUD_MODEL_DIR="./checkpoints"
BATCH_SIZE=512
NUM_GPU=8
NUM_NODE=1
IS_LOCAL="True"

JOB_NAME=deepspeech-$(date +%Y%m%d%H%M%S)
DS2_PATH=${PWD%/*}
cp -f pcloud_train.sh ${DS2_PATH}

paddlecloud submit \
-image bootstrapper:5000/paddlepaddle/pcloud_ds2:latest \
-jobname ${JOB_NAME} \
-cpu ${NUM_GPU} \
-gpu ${NUM_GPU} \
-memory 64Gi \
-parallelism ${NUM_NODE} \
-pscpu 1 \
-pservers 1 \
-psmemory 64Gi \
-passes 1 \
-entry "sh pcloud_train.sh ${TRAIN_MANIFEST} ${DEV_MANIFEST} ${CLOUD_MODEL_DIR} ${NUM_GPU} ${BATCH_SIZE} ${IS_LOCAL}" \
${DS2_PATH}

rm ${DS2_PATH}/pcloud_train.sh
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
# Per-node paddlecloud entry: keep this node's shard of the manifests,
# then launch DeepSpeech2 training.

TRAIN_MANIFEST=$1
DEV_MANIFEST=$2
MODEL_PATH=$3
NUM_GPU=$4
BATCH_SIZE=$5
IS_LOCAL=$6

# Select only this trainer's share of each manifest.
python ./cloud/split_data.py \
--in_manifest_path=${TRAIN_MANIFEST} \
--out_manifest_path='/local.manifest.train'

python ./cloud/split_data.py \
--in_manifest_path=${DEV_MANIFEST} \
--out_manifest_path='/local.manifest.dev'

mkdir ./logs

# NOTE: the original passed --output_model_dir twice ('./checkpoints'
# then ${MODEL_PATH}); argparse keeps the last value, so the hard-coded
# one was dead and has been removed.
python -u train.py \
--batch_size=${BATCH_SIZE} \
--trainer_count=${NUM_GPU} \
--num_passes=200 \
--num_proc_data=${NUM_GPU} \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=${IS_LOCAL} \
--share_rnn_weights=True \
--train_manifest='/local.manifest.train' \
--dev_manifest='/local.manifest.dev' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--output_model_dir=${MODEL_PATH} \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped' \
2>&1 | tee ./logs/train.log
|
@ -0,0 +1,22 @@
|
||||
#! /usr/bin/env bash
# Pack LibriSpeech audio into shards and upload them to paddlecloud,
# producing cloud-side manifests under cloud_manifests/.

mkdir cloud_manifests

IN_MANIFESTS="../data/librispeech/manifest.train ../data/librispeech/manifest.dev-clean ../data/librispeech/manifest.test-clean"
OUT_MANIFESTS="cloud_manifests/cloud.manifest.train cloud_manifests/cloud.manifest.dev cloud_manifests/cloud.manifest.test"
CLOUD_DATA_DIR="/pfs/dlnel/home/USERNAME/deepspeech2/data/librispeech"
NUM_SHARDS=50

if ! python upload_data.py \
    --in_manifest_paths ${IN_MANIFESTS} \
    --out_manifest_paths ${OUT_MANIFESTS} \
    --cloud_data_dir ${CLOUD_DATA_DIR} \
    --num_shards ${NUM_SHARDS}
then
    echo "Upload Data Failed!"
    exit 1
fi

echo "All Done."
|
@ -0,0 +1,41 @@
|
||||
"""This tool is used for splitting data into each node of
|
||||
paddlecloud. This script should be called in paddlecloud.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--in_manifest_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input manifest path for all nodes.")
|
||||
parser.add_argument(
|
||||
"--out_manifest_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output manifest file path for current node.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def split_data(in_manifest_path, out_manifest_path):
    """Write the subset of manifest lines owned by this cloud node.

    PaddleCloud provisions each node with ``/trainer_id`` and
    ``/trainer_count`` files; manifest lines are dealt round-robin so
    line *i* belongs to trainer ``i % trainer_count``.
    """
    with open("/trainer_id", "r") as f:
        trainer_id = int(f.readline()[:-1])
    with open("/trainer_count", "r") as f:
        trainer_count = int(f.readline()[:-1])

    selected = [
        "%s\n" % json_line.strip()
        for index, json_line in enumerate(open(in_manifest_path, 'r'))
        if index % trainer_count == trainer_id
    ]
    with open(out_manifest_path, 'w') as f:
        f.writelines(selected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
split_data(args.in_manifest_path, args.out_manifest_path)
|
@ -0,0 +1,129 @@
|
||||
"""This script is for uploading data for DeepSpeech2 training on paddlecloud.
|
||||
|
||||
Steps:
|
||||
1. Read original manifests and extract local sound files.
|
||||
2. Tar all local sound files into multiple tar files and upload them.
|
||||
3. Modify original manifests with updated paths in cloud filesystem.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import os
|
||||
import tarfile
|
||||
import sys
|
||||
import argparse
|
||||
import shutil
|
||||
from subprocess import call
|
||||
import _init_paths
|
||||
from data_utils.utils import read_manifest
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--in_manifest_paths",
|
||||
default=[
|
||||
"../datasets/manifest.train", "../datasets/manifest.dev",
|
||||
"../datasets/manifest.test"
|
||||
],
|
||||
type=str,
|
||||
nargs='+',
|
||||
help="Local filepaths of input manifests to load, pack and upload."
|
||||
"(default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--out_manifest_paths",
|
||||
default=[
|
||||
"./cloud.manifest.train", "./cloud.manifest.dev",
|
||||
"./cloud.manifest.test"
|
||||
],
|
||||
type=str,
|
||||
nargs='+',
|
||||
help="Local filepaths of modified manifests to write to. "
|
||||
"(default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--cloud_data_dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Destination directory on paddlecloud to upload data to.")
|
||||
parser.add_argument(
|
||||
"--num_shards",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Number of parts to split data to. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--local_tmp_dir",
|
||||
default="./tmp/",
|
||||
type=str,
|
||||
help="Local directory for storing temporary data. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def upload_data(in_manifest_path_list, out_manifest_path_list, local_tmp_dir,
                upload_tar_dir, num_shards):
    """Pack the sound files listed in the manifests into ``num_shards``
    tar files, upload them to paddlecloud, and write new manifests whose
    audio paths address members inside the uploaded tars.

    :param in_manifest_path_list: Local manifests to read.
    :param out_manifest_path_list: Output manifest paths, parallel to
                                   the input list.
    :param local_tmp_dir: Scratch directory for building tar files.
    :param upload_tar_dir: Destination directory on paddlecloud.
    :param num_shards: Number of tar shards to split the audio into.
    """
    # Count audio entries across all manifests to size each shard.
    total_line = 0
    for manifest_path in in_manifest_path_list:
        with open(manifest_path, 'r') as f:
            total_line += len(f.readlines())
    line_per_tar = (total_line // num_shards) + 1

    # Pack and upload shard by shard.
    line_count, tar_file = 0, None
    for manifest_path, out_manifest_path in zip(in_manifest_path_list,
                                                out_manifest_path_list):
        manifest = read_manifest(manifest_path)
        out_manifest = []
        for json_data in manifest:
            sound_filepath = json_data['audio_filepath']
            sound_filename = os.path.basename(sound_filepath)
            if line_count % line_per_tar == 0:
                # Current shard is full: ship it, then start a new one.
                if tar_file is not None:  # was `!= None` (anti-idiom)
                    tar_file.close()
                    pcloud_cp(tar_path, upload_tar_dir)
                    os.remove(tar_path)
                tar_name = 'part-%s-of-%s.tar' % (
                    str(line_count // line_per_tar).zfill(5),
                    str(num_shards).zfill(5))
                tar_path = os.path.join(local_tmp_dir, tar_name)
                tar_file = tarfile.open(tar_path, 'w')
            tar_file.add(sound_filepath, arcname=sound_filename)
            line_count += 1
            # Address the file inside the uploaded tar:
            # "tar:<tar path>#<member name>".
            json_data['audio_filepath'] = "tar:%s#%s" % (
                os.path.join(upload_tar_dir, tar_name), sound_filename)
            out_manifest.append("%s\n" % json.dumps(json_data))
        with open(out_manifest_path, 'w') as f:
            f.writelines(out_manifest)
        pcloud_cp(out_manifest_path, upload_tar_dir)
    # Ship the final, possibly partial, shard. Guard against empty
    # manifests, where no tar was ever opened (original crashed here).
    if tar_file is not None:
        tar_file.close()
        pcloud_cp(tar_path, upload_tar_dir)
        os.remove(tar_path)
|
||||
|
||||
|
||||
def pcloud_mkdir(dir):
    """Create a directory on the PaddleCloud filesystem.

    Raises IOError when the ``paddlecloud mkdir`` command fails.
    """
    status = call(['paddlecloud', 'mkdir', dir])
    if status != 0:
        raise IOError("PaddleCloud mkdir failed: %s." % dir)
|
||||
|
||||
|
||||
def pcloud_cp(src, dst):
    """Copy a file between the local and PaddleCloud filesystems.

    Works in either direction; ``paddlecloud cp`` infers the direction
    from the paths. Raises IOError when the command reports failure.
    """
    status = call(['paddlecloud', 'cp', src, dst])
    if status != 0:
        raise IOError("PaddleCloud cp failed: from [%s] to [%s]." % (src, dst))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if not os.path.exists(args.local_tmp_dir):
|
||||
os.makedirs(args.local_tmp_dir)
|
||||
pcloud_mkdir(args.cloud_data_dir)
|
||||
|
||||
upload_data(args.in_manifest_paths, args.out_manifest_paths,
|
||||
args.local_tmp_dir, args.cloud_data_dir, args.num_shards)
|
||||
|
||||
shutil.rmtree(args.local_tmp_dir)
|
@ -0,0 +1,8 @@
|
||||
[
|
||||
{
|
||||
"type": "shift",
|
||||
"params": {"min_shift_ms": -5,
|
||||
"max_shift_ms": 5},
|
||||
"prob": 1.0
|
||||
}
|
||||
]
|
@ -0,0 +1,39 @@
|
||||
[
|
||||
{
|
||||
"type": "noise",
|
||||
"params": {"min_snr_dB": 40,
|
||||
"max_snr_dB": 50,
|
||||
"noise_manifest_path": "datasets/manifest.noise"},
|
||||
"prob": 0.6
|
||||
},
|
||||
{
|
||||
"type": "impulse",
|
||||
"params": {"impulse_manifest_path": "datasets/manifest.impulse"},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "speed",
|
||||
"params": {"min_speed_rate": 0.95,
|
||||
"max_speed_rate": 1.05},
|
||||
"prob": 0.5
|
||||
},
|
||||
{
|
||||
"type": "shift",
|
||||
"params": {"min_shift_ms": -5,
|
||||
"max_shift_ms": 5},
|
||||
"prob": 1.0
|
||||
},
|
||||
{
|
||||
"type": "volume",
|
||||
"params": {"min_gain_dBFS": -10,
|
||||
"max_gain_dBFS": 10},
|
||||
"prob": 0.0
|
||||
},
|
||||
{
|
||||
"type": "bayesian_normal",
|
||||
"params": {"target_db": -20,
|
||||
"prior_db": -20,
|
||||
"prior_samples": 100},
|
||||
"prob": 0.0
|
||||
}
|
||||
]
|
@ -0,0 +1,110 @@
|
||||
"""Prepare Aishell mandarin dataset
|
||||
|
||||
Download, unpack and create manifest files.
|
||||
Manifest file is a json-format file with each line containing the
|
||||
meta data (i.e. audio filepath, transcript and audio duration)
|
||||
of each audio file in the data set.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import soundfile
|
||||
import json
|
||||
import argparse
|
||||
from data_utils.utility import download, unpack
|
||||
|
||||
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
|
||||
|
||||
URL_ROOT = 'http://www.openslr.org/resources/33'
|
||||
DATA_URL = URL_ROOT + '/data_aishell.tgz'
|
||||
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default=DATA_HOME + "/Aishell",
|
||||
type=str,
|
||||
help="Directory to save the dataset. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--manifest_prefix",
|
||||
default="manifest",
|
||||
type=str,
|
||||
help="Filepath prefix for output manifests. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def create_manifest(data_dir, manifest_path_prefix):
    """Create train/dev/test manifest files for the Aishell corpus.

    Each manifest line is a json dict carrying the audio filepath, its
    duration in seconds, and the whitespace-stripped transcript.
    """
    print("Creating manifest %s ..." % manifest_path_prefix)
    json_lines = []
    transcript_path = os.path.join(data_dir, 'transcript',
                                   'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    for line in codecs.open(transcript_path, 'r', 'utf-8'):
        line = line.strip()
        if line == '': continue
        audio_id, text = line.split(' ', 1)
        # Mandarin transcripts carry no meaningful spaces; drop them all.
        text = ''.join(text.split())
        transcript_dict[audio_id] = text

    data_types = ['train', 'dev', 'test']
    # renamed loop variable: `type` shadowed the builtin
    for dtype in data_types:
        del json_lines[:]
        audio_dir = os.path.join(data_dir, 'wav', dtype)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname)
                audio_id = fname[:-4]
                # if no transcription for audio then skipped
                if audio_id not in transcript_dict:
                    continue
                audio_data, samplerate = soundfile.read(audio_path)
                # float() applied to the numerator, matching the sibling
                # librispeech.py and independent of __future__ division
                duration = float(len(audio_data)) / samplerate
                text = transcript_dict[audio_id]
                json_lines.append(
                    json.dumps(
                        {
                            'audio_filepath': audio_path,
                            'duration': duration,
                            'text': text
                        },
                        ensure_ascii=False))
        manifest_path = manifest_path_prefix + '.' + dtype
        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
            for line in json_lines:
                fout.write(line + '\n')
|
||||
|
||||
|
||||
def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if os.path.exists(data_dir):
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    else:
        filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
        # The corpus nests one tar per speaker under wav/; unpack each.
        audio_dir = os.path.join(data_dir, 'wav')
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for ftar in filelist:
                unpack(os.path.join(subfolder, ftar), subfolder, True)
    create_manifest(data_dir, manifest_path)
|
||||
|
||||
|
||||
def main():
    """Entry point: resolve '~' in --target_dir, then prepare the data."""
    # expanduser() is a no-op for paths without a leading '~'.
    args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(
        url=DATA_URL,
        md5sum=MD5_DATA,
        target_dir=args.target_dir,
        manifest_path=args.manifest_prefix)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,148 @@
|
||||
"""Prepare Librispeech ASR datasets.
|
||||
|
||||
Download, unpack and create manifest files.
|
||||
Manifest file is a json-format file with each line containing the
|
||||
meta data (i.e. audio filepath, transcript and audio duration)
|
||||
of each audio file in the data set.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import soundfile
|
||||
import json
|
||||
import codecs
|
||||
from data_utils.utility import download, unpack
|
||||
|
||||
URL_ROOT = "http://www.openslr.org/resources/12"
|
||||
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
|
||||
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
|
||||
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
|
||||
URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
|
||||
URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
|
||||
URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
|
||||
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
|
||||
|
||||
MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
|
||||
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
|
||||
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
|
||||
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
|
||||
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
|
||||
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
|
||||
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default='~/.cache/paddle/dataset/speech/libri',
|
||||
type=str,
|
||||
help="Directory to save the dataset. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--manifest_prefix",
|
||||
default="manifest",
|
||||
type=str,
|
||||
help="Filepath prefix for output manifests. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--full_download",
|
||||
default="True",
|
||||
type=distutils.util.strtobool,
|
||||
help="Download all datasets for Librispeech."
|
||||
" If False, only download a minimal requirement (test-clean, dev-clean"
|
||||
" train-clean-100). (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each
    line holding the meta data (audio filepath, transcription text and
    audio duration) of one audio file.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        # Each LibriSpeech chapter folder holds one '*.trans.txt' file
        # with "<utterance-id> <TRANSCRIPT>" per line.
        text_filelist = [fn for fn in filelist if fn.endswith('trans.txt')]
        if not text_filelist:
            continue
        text_filepath = os.path.join(data_dir, subfolder, text_filelist[0])
        for line in open(text_filepath):
            segments = line.strip().split()
            text = ' '.join(segments[1:]).lower()
            audio_filepath = os.path.join(data_dir, subfolder,
                                          segments[0] + '.flac')
            audio_data, samplerate = soundfile.read(audio_filepath)
            duration = float(len(audio_data)) / samplerate
            json_lines.append(
                json.dumps({
                    'audio_filepath': audio_filepath,
                    'duration': duration,
                    'text': text
                }))
    with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')
|
||||
|
||||
|
||||
def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if os.path.exists(os.path.join(target_dir, "LibriSpeech")):
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    else:
        # download, then unpack
        filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)
|
||||
|
||||
|
||||
def main():
    """Download the requested LibriSpeech subsets and emit manifests."""
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    # Minimal subsets for a quick eval cycle; same order as before.
    subsets = [
        (URL_TEST_CLEAN, MD5_TEST_CLEAN, "test-clean"),
        (URL_DEV_CLEAN, MD5_DEV_CLEAN, "dev-clean"),
    ]
    if args.full_download:
        subsets += [
            (URL_TRAIN_CLEAN_100, MD5_TRAIN_CLEAN_100, "train-clean-100"),
            (URL_TEST_OTHER, MD5_TEST_OTHER, "test-other"),
            (URL_DEV_OTHER, MD5_DEV_OTHER, "dev-other"),
            (URL_TRAIN_CLEAN_360, MD5_TRAIN_CLEAN_360, "train-clean-360"),
            (URL_TRAIN_OTHER_500, MD5_TRAIN_OTHER_500, "train-other-500"),
        ]

    for url, md5sum, subset in subsets:
        prepare_dataset(
            url=url,
            md5sum=md5sum,
            target_dir=os.path.join(args.target_dir, subset),
            manifest_path=args.manifest_prefix + "." + subset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,128 @@
|
||||
"""Prepare CHiME3 background data.
|
||||
|
||||
Download, unpack and create manifest files.
|
||||
Manifest file is a json-format file with each line containing the
|
||||
meta data (i.e. audio filepath, transcript and audio duration)
|
||||
of each audio file in the data set.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import distutils.util
|
||||
import os
|
||||
import wget
|
||||
import zipfile
|
||||
import argparse
|
||||
import soundfile
|
||||
import json
|
||||
from paddle.v2.dataset.common import md5file
|
||||
|
||||
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
|
||||
|
||||
URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
|
||||
MD5 = "c3ff512618d7a67d4f85566ea1bc39ec"
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default=DATA_HOME + "/chime3_background",
|
||||
type=str,
|
||||
help="Directory to save the dataset. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--manifest_filepath",
|
||||
default="manifest.chime3.background",
|
||||
type=str,
|
||||
help="Filepath for output manifests. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def download(url, md5sum, target_dir, filename=None):
    """Download file from url to target_dir, and check md5sum.

    :param url: Source URL.
    :param md5sum: Expected md5 of the downloaded file.
    :param target_dir: Directory to save the file into (created if
                       missing).
    :param filename: Optional local filename; defaults to the last URL
                     path component.
    :return: Full path of the (possibly pre-existing) downloaded file.
    :raises RuntimeError: If the checksum does not match after download.
    """
    if filename is None:  # was `== None` (anti-idiom)
        filename = url.split("/")[-1]
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, filename)
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        wget.download(url, target_dir)
        print("\nMD5 Chesksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath
|
||||
|
||||
|
||||
def unpack(filepath, target_dir):
    """Unpack a .zip, .tar or .tar.gz archive into target_dir.

    :raises ValueError: For any other file extension.
    """
    # Local import: tarfile is needed only here and was never imported
    # at module level in this file.
    import tarfile
    print("Unpacking %s ..." % filepath)
    if filepath.endswith('.zip'):
        zip_file = zipfile.ZipFile(filepath, 'r')
        zip_file.extractall(target_dir)
        zip_file.close()
    elif filepath.endswith('.tar') or filepath.endswith('.tar.gz'):
        # BUG FIX: the original called zipfile.open(), which does not
        # exist; tar archives must be opened with tarfile.open().
        tar = tarfile.open(filepath)
        tar.extractall(target_dir)
        tar.close()
    else:
        raise ValueError("File format is not supported for unpacking.")
|
||||
|
||||
|
||||
def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each
    line holding the meta data (audio filepath, empty transcription and
    audio duration) of one audio file.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        wav_files = (fn for fn in filelist if fn.endswith('.wav'))
        for filename in wav_files:
            filepath = os.path.join(data_dir, subfolder, filename)
            audio_data, samplerate = soundfile.read(filepath)
            duration = float(len(audio_data)) / samplerate
            # Background noise recordings have no transcript.
            json_lines.append(
                json.dumps({
                    'audio_filepath': filepath,
                    'duration': duration,
                    'text': ''
                }))
    with open(manifest_path, 'w') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')
|
||||
|
||||
|
||||
def prepare_chime3(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if os.path.exists(os.path.join(target_dir, "CHiME3")):
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    else:
        # download the combined bundle ...
        filepath = download(url, md5sum, target_dir,
                            "myairbridge-AG0Y3DNBE5IWRRTV.zip")
        unpack(filepath, target_dir)
        # ... then the four background-environment zips inside it.
        for env in ('bus', 'caf', 'ped', 'str'):
            unpack(
                os.path.join(target_dir, 'CHiME3_background_%s.zip' % env),
                target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)
|
||||
|
||||
|
||||
def main():
    """Fetch the CHiME3 background corpus and write its manifest."""
    prepare_chime3(URL, MD5, args.target_dir, args.manifest_filepath)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,16 @@
|
||||
#! /usr/bin/env bash
# download data, generate manifests
if ! PYTHONPATH=../../:$PYTHONPATH python voxforge.py \
    --manifest_prefix='./manifest' \
    --target_dir='~/.cache/paddle/dataset/speech/VoxForge' \
    --is_merge_dialect=True \
    --dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian'
then
    echo "Prepare VoxForge failed. Terminated."
    exit 1
fi

echo "VoxForge Data preparation done."
exit 0
|
@ -0,0 +1,221 @@
|
||||
"""Prepare VoxForge dataset
|
||||
|
||||
Download, unpack and create manifest files.
|
||||
Manifest file is a json-format file with each line containing the
|
||||
meta data (i.e. audio filepath, transcript and audio duration)
|
||||
of each audio file in the data set.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import soundfile
|
||||
import json
|
||||
import argparse
|
||||
import shutil
|
||||
import subprocess
|
||||
from data_utils.utility import download_multi, unpack, getfile_insensitive
|
||||
|
||||
DATA_HOME = '~/.cache/paddle/dataset/speech'
|
||||
|
||||
DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \
|
||||
'Audio/Main/16kHz_16bit'
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--target_dir",
|
||||
default=DATA_HOME + "/VoxForge",
|
||||
type=str,
|
||||
help="Directory to save the dataset. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--dialects",
|
||||
default=[
|
||||
'american', 'british', 'australian', 'european', 'irish', 'canadian',
|
||||
'indian'
|
||||
],
|
||||
nargs='+',
|
||||
type=str,
|
||||
help="Dialect types. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--is_merge_dialect",
|
||||
default=True,
|
||||
type=bool,
|
||||
help="If set True, manifests of american dialect and canadian dialect will "
|
||||
"be merged to american-canadian dialect; manifests of british "
|
||||
"dialect, irish dialect and australian dialect will be merged to "
|
||||
"commonwealth dialect. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--manifest_prefix",
|
||||
default="manifest",
|
||||
type=str,
|
||||
help="Filepath prefix for output manifests. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def download_and_unpack(target_dir, url):
    """Recursively fetch the VoxForge tgz archives under `url` into
    `target_dir`/tgz, then unpack each archive into `target_dir`/audio.

    :param target_dir: Root directory for the downloaded dataset.
    :param url: Remote directory holding the .tgz corpus archives.
    """
    wget_args = '-q -l 1 -N -nd -c -e robots=off -A tgz -r -np'
    tgz_dir = os.path.join(target_dir, 'tgz')
    exit_code = download_multi(url, tgz_dir, wget_args)
    if exit_code != 0:
        print('Download tgz audio files failed with exit code %d.' % exit_code)
        return
    print('Download done, start unpacking ...')
    audio_dir = os.path.join(target_dir, 'audio')
    for root, _, filenames in os.walk(tgz_dir):
        for filename in filenames:
            print(filename)
            if filename.endswith('.tgz'):
                unpack(os.path.join(root, filename), audio_dir)
|
||||
|
||||
|
||||
def select_dialects(target_dir, dialect_list):
    """Classify audio files by dialect.

    For each requested dialect, grep the unpacked README files for a
    matching "pronunciation dialect" entry and symlink every matching
    speaker directory under <target_dir>/dialect/<dialect>.

    :param target_dir: Root directory of the unpacked dataset.
    :param dialect_list: Dialect names to select (e.g. ['american', ...]).
    """
    dialect_root_dir = os.path.join(target_dir, 'dialect')
    if os.path.exists(dialect_root_dir):
        shutil.rmtree(dialect_root_dir)
    os.mkdir(dialect_root_dir)
    audio_dir = os.path.abspath(os.path.join(target_dir, 'audio'))
    for dialect in dialect_list:
        # Filter files by dialect. Build the command via implicit string
        # concatenation: the original backslash continuation inside the
        # string literal embedded source indentation into the shell command.
        command = ('find %s -iwholename "*etc/readme*" -exec egrep -iHl '
                   '"pronunciation dialect.*%s" {} \\;' % (audio_dir, dialect))
        proc = subprocess.Popen(
            command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True)
        output, _ = proc.communicate()
        dialect_dir = os.path.join(dialect_root_dir, dialect)
        if os.path.exists(dialect_dir):
            shutil.rmtree(dialect_dir)
        os.mkdir(dialect_dir)
        for readme_path in output.splitlines():
            # README lives in <speaker>/etc/, so the speaker directory is
            # two levels up from the matched file.
            speaker_dir = os.path.dirname(os.path.dirname(readme_path))
            link_name = os.path.basename(os.path.normpath(speaker_dir))
            os.symlink(speaker_dir, os.path.join(dialect_dir, link_name))
|
||||
|
||||
|
||||
def generate_manifest(data_dir, manifest_path):
    """Create a json-lines manifest summarizing every audio file reachable
    from `data_dir` (one line per utterance: filepath, duration, transcript).

    :param data_dir: Directory of per-speaker symbolic links, as produced
                     by select_dialects().
    :param manifest_path: Output filepath for the manifest.
    """
    json_lines = []

    for path in os.listdir(data_dir):
        audio_link = os.path.join(data_dir, path)
        assert os.path.islink(
            audio_link), '%s should be symbolic link.' % audio_link
        actual_audio_dir = os.path.abspath(os.readlink(audio_link))

        # Speakers store audio either as wav or flac subdirectories.
        audio_type = ''
        if os.path.isdir(os.path.join(actual_audio_dir, 'wav')):
            audio_type = 'wav'
        elif os.path.isdir(os.path.join(actual_audio_dir, 'flac')):
            audio_type = 'flac'
        else:
            print('Unknown audio type, skipped processing %s.' %
                  actual_audio_dir)
            continue

        etc_dir = os.path.join(actual_audio_dir, 'etc')
        prompts_file = os.path.join(etc_dir, 'PROMPTS')
        if not os.path.isfile(prompts_file):
            print('PROMPTS file missing, skip processing %s.' %
                  actual_audio_dir)
            continue

        readme_file = getfile_insensitive(os.path.join(etc_dir, 'README'))
        if readme_file is None:
            print('README file missing, skip processing %s.' % actual_audio_dir)
            continue

        # open() replaces the Python-2-only file() builtin; the context
        # manager also guarantees the handle is closed.
        with open(prompts_file) as fprompts:
            for line in fprompts:
                u, trans = line.strip().split(None, 1)
                u_parts = u.split('/')

                # Try to normalize the date inside the utterance path to a
                # canonical YYYYMMDD form; this is best-effort only, so any
                # failure leaves the path untouched. (Requires the
                # module-level `import datetime`.)
                try:
                    speaker, date, sfx = u_parts[-3].split('-')
                    obj = datetime.datetime.strptime(date, '%y.%m.%d')
                    formatted = obj.strftime('%Y%m%d')
                    u_parts[-3] = '-'.join([speaker, formatted, sfx])
                except Exception:
                    pass

                if len(u_parts) < 2:
                    u_parts = [audio_type] + u_parts
                u_parts[-2] = audio_type
                u_parts[-1] += '.' + audio_type
                u = os.path.join(actual_audio_dir, '/'.join(u_parts[-2:]))

                if not os.path.isfile(u):
                    print('Audio file missing, skip processing %s.' % u)
                    continue

                if os.stat(u).st_size == 0:
                    print('Empty audio file, skip processing %s.' % u)
                    continue

                # Only keep transcripts that are fully upper-case letters
                # (apostrophes and hyphens-as-spaces allowed).
                trans = trans.strip().replace('-', ' ')
                if not trans.isupper() or \
                        not trans.strip().replace(' ', '').replace("'", "").isalpha():
                    print("Transcript not normalized properly, skip processing %s."
                          % u)
                    continue

                audio_data, samplerate = soundfile.read(u)
                duration = float(len(audio_data)) / samplerate
                json_lines.append(
                    json.dumps({
                        'audio_filepath': u,
                        'duration': duration,
                        'text': trans.lower()
                    }))

    with codecs.open(manifest_path, 'w', 'utf-8') as fout:
        for line in json_lines:
            fout.write(line + '\n')
|
||||
|
||||
|
||||
def merge_manifests(manifest_files, save_path):
    """Merge several manifest files into a single manifest.

    :param manifest_files: Paths of the manifest files to merge, in order.
    :type manifest_files: list
    :param save_path: Output path of the merged manifest.
    :type save_path: basestring
    """
    lines = []
    for manifest_file in manifest_files:
        # Use a context manager so each input file is closed promptly
        # instead of leaking the handle until garbage collection.
        with codecs.open(manifest_file, 'r', 'utf-8') as fin:
            lines += fin.readlines()

    with codecs.open(save_path, 'w', 'utf-8') as fout:
        for line in lines:
            fout.write(line)
|
||||
|
||||
|
||||
def prepare_dataset(url, dialects, target_dir, manifest_prefix, is_merge):
    """Download the corpus, sort it by dialect, and write one manifest per
    requested dialect, optionally merging them into the two standard
    dialect groups (american-canadian and commonwealth).
    """
    download_and_unpack(target_dir, url)
    select_dialects(target_dir, dialects)
    american_canadian_manifests = []
    commonwealth_manifests = []
    for dialect in dialects:
        dialect_dir = os.path.join(target_dir, 'dialect', dialect)
        manifest_fpath = manifest_prefix + '.' + dialect
        # Track which manifests belong to each merge group.
        if dialect in ('american', 'canadian'):
            american_canadian_manifests.append(manifest_fpath)
        if dialect in ('australian', 'british', 'irish'):
            commonwealth_manifests.append(manifest_fpath)
        generate_manifest(dialect_dir, manifest_fpath)

    if is_merge:
        if american_canadian_manifests:
            merge_manifests(american_canadian_manifests,
                            manifest_prefix + '.american-canadian')
        if commonwealth_manifests:
            merge_manifests(commonwealth_manifests,
                            manifest_prefix + '.commonwealth')
|
||||
|
||||
|
||||
def main():
    """Entry point: expand a leading '~' in --target_dir, then run the full
    VoxForge preparation pipeline with the parsed command-line options."""
    # expanduser() is a no-op unless the path starts with '~', so the
    # original explicit startswith('~') guard is unnecessary.
    args.target_dir = os.path.expanduser(args.target_dir)
    prepare_dataset(DATA_URL, args.dialects, args.target_dir,
                    args.manifest_prefix, args.is_merge_dialect)


if __name__ == '__main__':
    main()
|
@ -0,0 +1,685 @@
|
||||
"""Contains the audio segment class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import io
|
||||
import struct
|
||||
import re
|
||||
import soundfile
|
||||
import resampy
|
||||
from scipy import signal
|
||||
import random
|
||||
import copy
|
||||
|
||||
|
||||
class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate):
        """Create audio segment from samples.

        Samples are converted to float32 internally, with integer types
        scaled to [-1, 1]; multi-channel input is averaged down to mono.
        """
        self._samples = self._convert_samples_to_float32(samples)
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            # Average across channels to make the segment monaural.
            self._samples = np.mean(self._samples, 1)

    def __eq__(self, other):
        """Return whether two objects are equal."""
        if type(other) is not type(self):
            return False
        if self._sample_rate != other._sample_rate:
            return False
        if self._samples.shape != other._samples.shape:
            return False
        if np.any(self.samples != other._samples):
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    def __str__(self):
        """Return human-readable representation of segment."""
        return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
                "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
                                self.duration, self.rms_db))

    @classmethod
    def from_file(cls, file):
        """Create audio segment from audio file.

        :param file: Filepath or file object to audio file.
        :type file: basestring|file
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        # Names like "xxx.seqbin_5" refer to an utterance inside a
        # multi-utterance sequence binary file.
        if isinstance(file, basestring) and re.findall(r".seqbin_\d+$", file):
            return cls.from_sequence_file(file)
        else:
            samples, sample_rate = soundfile.read(file, dtype='float32')
            return cls(samples, sample_rate)

    @classmethod
    def slice_from_file(cls, file, start=None, end=None):
        """Loads a small section of an audio without having to load
        the entire file into the memory which can be incredibly wasteful.

        :param file: Input audio filepath or file object.
        :type file: basestring|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :return: AudioSegment instance of the specified slice of the input
                 audio file.
        :rtype: AudioSegment
        :raise ValueError: If start or end is incorrectly set, e.g. out of
                           bounds in time.
        """
        sndfile = soundfile.SoundFile(file)
        sample_rate = sndfile.samplerate
        duration = float(len(sndfile)) / sample_rate
        start = 0. if start is None else start
        # Bug fix: the default for `end` is the full duration (read to the
        # end of the file, as documented), not 0.0 -- a 0.0 default made
        # end=None raise "start is later than end" for any nonzero start.
        end = duration if end is None else end
        if start < 0.0:
            start += duration
        if end < 0.0:
            end += duration
        if start < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start)
        if end < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end)
        if start > end:
            raise ValueError("The slice start position (%f s) is later than "
                             "the slice end position (%f s)." % (start, end))
        if end > duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end, duration))
        start_frame = int(start * sample_rate)
        end_frame = int(end * sample_rate)
        sndfile.seek(start_frame)
        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return cls(data, sample_rate)

    @classmethod
    def from_sequence_file(cls, filepath):
        """Create audio segment from sequence file. Sequence file is a binary
        file containing a collection of multiple audio files, with several
        header bytes in the head indicating the offsets of each audio byte
        data chunk.

        The format is:

            4 bytes (int, version),
            4 bytes (int, num of utterance),
            4 bytes (int, bytes per header),
            [bytes_per_header*(num_utterance+1)] bytes (offsets for each
            audio),
            audio_bytes_data_of_1st_utterance,
            audio_bytes_data_of_2nd_utterance,
            ......

        Sequence file name must end with ".seqbin". And the filename of the
        5th utterance's audio file in sequence file "xxx.seqbin" must be
        "xxx.seqbin_5", with "5" indicating the utterance index within this
        sequence file (starting from 1).

        :param filepath: Filepath of sequence file.
        :type filepath: basestring
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        # Parse the "<name>.seqbin_<index>" filepath convention.
        matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
        if matches is None:
            raise IOError("File type of %s is not supported" % filepath)
        filename = matches.group(1)
        fileno = int(matches.group(2))

        # Read headers and the target chunk; `with` guarantees the file is
        # closed even if parsing raises (the original leaked the handle).
        with open(filename, 'rb') as f:
            version = f.read(4)  # version field is read but currently unused
            num_utterances = struct.unpack("i", f.read(4))[0]
            bytes_per_header = struct.unpack("i", f.read(4))[0]
            header_bytes = f.read(bytes_per_header * (num_utterances + 1))
            header = [
                struct.unpack("i", header_bytes[bytes_per_header * i:
                                                bytes_per_header * (i + 1)])[0]
                for i in range(num_utterances + 1)
            ]
            # Offsets are cumulative; chunk i spans header[i-1]..header[i].
            f.seek(header[fileno - 1])
            audio_bytes = f.read(header[fileno] - header[fileno - 1])

        # Create the audio segment, falling back to raw 16-bit 8kHz PCM when
        # the payload has no recognizable container header.
        try:
            return cls.from_bytes(audio_bytes)
        except Exception:
            samples = np.frombuffer(audio_bytes, dtype='int16')
            return cls(samples=samples, sample_rate=8000)

    @classmethod
    def from_bytes(cls, bytes):
        """Create audio segment from a byte string containing audio samples.

        :param bytes: Byte string containing audio samples.
        :type bytes: str
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        samples, sample_rate = soundfile.read(
            io.BytesIO(bytes), dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of audio segments together.

        :param *segments: Input audio segments to be concatenated.
        :type *segments: tuple of AudioSegment
        :return: Audio segment instance as concatenating results.
        :rtype: AudioSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any segments does not match.
        :raises TypeError: If any segment is not AudioSegment instance.
        """
        # Perform basic sanity-checks.
        if len(segments) == 0:
            raise ValueError("No audio segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only audio segments of the same type "
                                "can be concatenated.")
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent audio segment of the given duration and sample rate.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silent AudioSegment instance of the given duration.
        :rtype: AudioSegment
        """
        samples = np.zeros(int(duration * sample_rate))
        return cls(samples, sample_rate)

    def to_wav_file(self, filepath, dtype='float32'):
        """Save audio segment to disk as wav file.

        :param filepath: WAV filepath or file object to save the
                         audio segment.
        :type filepath: basestring|file
        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :raises TypeError: If dtype is not supported.
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        subtype_map = {
            'int16': 'PCM_16',
            'int32': 'PCM_32',
            'float32': 'FLOAT',
            'float64': 'DOUBLE'
        }
        soundfile.write(
            filepath,
            samples,
            self._sample_rate,
            format='WAV',
            subtype=subtype_map[dtype])

    def superimpose(self, other):
        """Add samples from another segment to those of this segment
        (sample-wise addition, not segment concatenation).

        Note that this is an in-place transformation.

        :param other: Segment containing samples to be added in.
        :type other: AudioSegment
        :raise TypeError: If type of two segments don't match.
        :raise ValueError: If the sample rates of the two segments are not
                           equal, or if the lengths of segments don't match.
        """
        # Bug fix: the original check was inverted (`if isinstance(...)`),
        # raising TypeError exactly when the types DID match, so the method
        # could never succeed.
        if not isinstance(other, type(self)):
            raise TypeError("Cannot add segments of different types: %s "
                            "and %s." % (type(self), type(other)))
        if self._sample_rate != other._sample_rate:
            raise ValueError("Sample rates must match to add segments.")
        if len(self._samples) != len(other._samples):
            raise ValueError("Segment lengths must match to add segments.")
        self._samples += other._samples

    def to_bytes(self, dtype='float32'):
        """Create a byte string containing the audio content.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :return: Byte string containing audio content.
        :rtype: str
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        # tobytes() replaces the tostring() alias removed in NumPy 2.0.
        return samples.tobytes()

    def gain_db(self, gain):
        """Apply gain in decibels to samples.

        Note that this is an in-place transformation.

        :param gain: Gain in decibels to apply to samples.
        :type gain: float|1darray
        """
        self._samples *= 10.**(gain / 20.)

    def change_speed(self, speed_rate):
        """Change the audio speed by linear interpolation.

        Note that this is an in-place transformation.

        :param speed_rate: Rate of speed change:
                           speed_rate > 1.0, speed up the audio;
                           speed_rate = 1.0, unchanged;
                           speed_rate < 1.0, slow down the audio;
                           speed_rate <= 0.0, not allowed, raise ValueError.
        :type speed_rate: float
        :raises ValueError: If speed_rate <= 0.0.
        """
        if speed_rate <= 0:
            raise ValueError("speed_rate should be greater than zero.")
        old_length = self._samples.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        self._samples = np.interp(new_indices, old_indices, self._samples)

    def normalize(self, target_db=-20, max_gain_db=300.0):
        """Normalize audio to be of the desired RMS value in decibels.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels. This value should be
                          less than 0.0 as 0.0 is full-scale audio.
        :type target_db: float
        :param max_gain_db: Max amount of gain in dB that can be applied for
                            normalization. This is to prevent nans when
                            attempting to normalize a signal consisting of
                            all zeros.
        :type max_gain_db: float
        :raises ValueError: If the required gain to normalize the segment to
                            the target_db value exceeds max_gain_db.
        """
        gain = target_db - self.rms_db
        if gain > max_gain_db:
            # Message fixed: the original read "the the probable gain have
            # exceeds".
            raise ValueError(
                "Unable to normalize segment to %f dB because the "
                "probable gain exceeds max_gain_db (%f dB)" %
                (target_db, max_gain_db))
        self.gain_db(min(max_gain_db, gain))

    def normalize_online_bayesian(self,
                                  target_db,
                                  prior_db,
                                  prior_samples,
                                  startup_delay=0.0):
        """Normalize audio using a production-compatible online/causal
        algorithm. This uses an exponential likelihood and gamma prior to
        make online estimates of the RMS even when there are very few samples.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels.
        :type target_db: float
        :param prior_db: Prior RMS estimate in decibels.
        :type prior_db: float
        :param prior_samples: Prior strength in number of samples.
        :type prior_samples: float
        :param startup_delay: Default 0.0s. If provided, this function will
                              accrue statistics for the first startup_delay
                              seconds before applying online normalization.
        :type startup_delay: float
        """
        # Estimate total RMS online.
        startup_sample_idx = min(self.num_samples - 1,
                                 int(self.sample_rate * startup_delay))
        prior_mean_squared = 10.**(prior_db / 10.)
        prior_sum_of_squares = prior_mean_squared * prior_samples
        cumsum_of_squares = np.cumsum(self.samples**2)
        sample_count = np.arange(self.num_samples) + 1
        if startup_sample_idx > 0:
            # Freeze the estimate during the startup window.
            cumsum_of_squares[:startup_sample_idx] = \
                cumsum_of_squares[startup_sample_idx]
            sample_count[:startup_sample_idx] = \
                sample_count[startup_sample_idx]
        mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
                                 (sample_count + prior_samples))
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
        self.gain_db(gain_db)

    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(
            self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate

    def pad_silence(self, duration, sides='both'):
        """Pad this audio sample with a period of silence.

        Note that this is an in-place transformation.

        :param duration: Length of silence in seconds to pad.
        :type duration: float
        :param sides: Position for padding:
                      'beginning' - adds silence in the beginning;
                      'end' - adds silence in the end;
                      'both' - adds silence in both the beginning and the end.
        :type sides: str
        :raises ValueError: If sides is not supported.
        """
        if duration == 0.0:
            return self
        cls = type(self)
        silence = self.make_silence(duration, self._sample_rate)
        if sides == "beginning":
            padded = cls.concatenate(silence, self)
        elif sides == "end":
            padded = cls.concatenate(self, silence)
        elif sides == "both":
            padded = cls.concatenate(silence, self, silence)
        else:
            raise ValueError("Unknown value for the sides %s" % sides)
        self._samples = padded._samples

    def shift(self, shift_ms):
        """Shift the audio in time. If `shift_ms` is positive, shift with time
        advance; if negative, shift with time delay. Silence are padded to
        keep the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in millseconds. If positive, shift with
                         time advance; if negative; shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("Absolute value of shift_ms should be smaller "
                             "than audio duration.")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
        if shift_samples > 0:
            # time advance: drop the head, zero-fill the tail
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay: drop the tail, zero-fill the head
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0

    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.

        Note that this is an in-place transformation.

        :param start_sec: Beginning of subsegment in seconds.
        :type start_sec: float
        :param end_sec: End of subsegment in seconds.
        :type end_sec: float
        :raise ValueError: If start_sec or end_sec is incorrectly set, e.g.
                           out of bounds in time.
        """
        start_sec = 0.0 if start_sec is None else start_sec
        end_sec = self.duration if end_sec is None else end_sec
        # Negative boundaries wrap around from the end of the segment.
        if start_sec < 0.0:
            start_sec = self.duration + start_sec
        if end_sec < 0.0:
            end_sec = self.duration + end_sec
        if start_sec < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start_sec)
        if end_sec < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end_sec)
        if start_sec > end_sec:
            raise ValueError("The slice start position (%f s) is later than "
                             "the end position (%f s)." % (start_sec, end_sec))
        if end_sec > self.duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end_sec, self.duration))
        start_sample = int(round(start_sec * self._sample_rate))
        end_sample = int(round(end_sec * self._sample_rate))
        self._samples = self._samples[start_sample:end_sample]

    def random_subsegment(self, subsegment_length, rng=None):
        """Cut the specified length of the audiosegment randomly.

        Note that this is an in-place transformation.

        :param subsegment_length: Subsegment length in seconds.
        :type subsegment_length: float
        :param rng: Random number generator state.
        :type rng: random.Random
        :raises ValueError: If the length of subsegment is greater than
                            the original segment.
        """
        rng = random.Random() if rng is None else rng
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = rng.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)

    def convolve(self, impulse_segment, allow_resample=False):
        """Convolve this audio segment with the given impulse segment.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segments.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        :raises ValueError: If the sample rate is not match between two
                            audio segments when resample is not allowed.
        """
        if allow_resample and self.sample_rate != impulse_segment.sample_rate:
            impulse_segment.resample(self.sample_rate)
        if self.sample_rate != impulse_segment.sample_rate:
            raise ValueError("Impulse segment's sample rate (%d Hz) is not "
                             "equal to base signal sample rate (%d Hz)." %
                             (impulse_segment.sample_rate, self.sample_rate))
        samples = signal.fftconvolve(self.samples, impulse_segment.samples,
                                     "full")
        self._samples = samples

    def convolve_and_normalize(self, impulse_segment, allow_resample=False):
        """Convolve and normalize the resulting audio segment so that it
        has the same average power as the input signal.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segments.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        """
        target_db = self.rms_db
        self.convolve(impulse_segment, allow_resample=allow_resample)
        self.normalize(target_db)

    def add_noise(self,
                  noise,
                  snr_dB,
                  allow_downsampling=False,
                  max_gain_db=300.0,
                  rng=None):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random subsegment
        of matching length is sampled from it and used instead.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param allow_downsampling: Whether to allow the noise signal to be
                                   downsampled to match the base signal sample
                                   rate.
        :type allow_downsampling: bool
        :param max_gain_db: Maximum amount of gain to apply to noise signal
                            before adding it in. This is to prevent attempting
                            to apply infinite gain to a zero signal.
        :type max_gain_db: float
        :param rng: Random number generator state.
        :type rng: None|random.Random
        :raises ValueError: If the sample rate does not match between the two
                            audio segments when downsampling is not allowed, or
                            if the duration of noise segments is shorter than
                            original audio segments.
        """
        rng = random.Random() if rng is None else rng
        if allow_downsampling and noise.sample_rate > self.sample_rate:
            # Bug fix: resample() mutates in place and returns None, so the
            # original `noise = noise.resample(...)` left `noise` as None.
            # Resample a copy so the caller's segment is untouched.
            noise = copy.deepcopy(noise)
            noise.resample(self.sample_rate)
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) is not equal to base "
                             "signal sample rate (%d Hz)." % (noise.sample_rate,
                                                              self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("Noise signal (%f sec) must be at least as long as"
                             " base signal (%f sec)." %
                             (noise.duration, self.duration))
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    @property
    def samples(self):
        """Return audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return root mean square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # square root => multiply by 10 instead of 20 for dBs
        mean_square = np.mean(self._samples**2)
        return 10 * np.log10(mean_square)

    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or float-point.
        Integers will be scaled to [-1, 1] in float32.
        """
        float32_samples = samples.astype('float32')
        # np.issubdtype replaces np.sctypes, which was removed in NumPy 2.0.
        if np.issubdtype(samples.dtype, np.integer):
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / 2**(bits - 1))
        elif np.issubdtype(samples.dtype, np.floating):
            pass
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return float32_samples

    def _convert_samples_from_float32(self, samples, dtype):
        """Convert sample type from float32 to dtype.

        Audio sample type is usually integer or float-point. For integer
        type, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.

        This is for writing a audio file.
        """
        dtype = np.dtype(dtype)
        output_samples = samples.copy()
        # Branch on the *target* dtype (internal samples are always float32,
        # so this is equivalent to the original source-dtype check for the
        # float branch, and clearer).
        if np.issubdtype(dtype, np.integer):
            bits = np.iinfo(dtype).bits
            output_samples *= (2**(bits - 1) / 1.)
            min_val = np.iinfo(dtype).min
            max_val = np.iinfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        elif np.issubdtype(dtype, np.floating):
            min_val = np.finfo(dtype).min
            max_val = np.finfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return output_samples.astype(dtype)
|
@ -0,0 +1,124 @@
|
||||
"""Contains the data augmentation pipeline."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import random
|
||||
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
|
||||
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
|
||||
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
|
||||
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
|
||||
from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
|
||||
from data_utils.augmentor.resample import ResampleAugmentor
|
||||
from data_utils.augmentor.online_bayesian_normalization import \
|
||||
OnlineBayesianNormalizationAugmentor
|
||||
|
||||
|
||||
class AugmentationPipeline(object):
    """Build a pre-processing pipeline with various augmentation models. Such a
    data augmentation pipeline is often leveraged to augment the training
    samples to make the model invariant to certain types of perturbations in
    the real world, improving the model's generalization ability.

    The pipeline is built according to the augmentation configuration in a
    json string, e.g.

    .. code-block::

        [ {
                "type": "noise",
                "params": {"min_snr_dB": 10,
                           "max_snr_dB": 20,
                           "noise_manifest_path": "datasets/manifest.noise"},
                "prob": 0.0
            },
            {
                "type": "speed",
                "params": {"min_speed_rate": 0.9,
                           "max_speed_rate": 1.1},
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {"min_shift_ms": -5,
                           "max_shift_ms": 5},
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {"min_gain_dBFS": -10,
                           "max_gain_dBFS": 10},
                "prob": 0.0
            },
            {
                "type": "bayesian_normal",
                "params": {"target_db": -20,
                           "prior_db": -20,
                           "prior_samples": 100},
                "prob": 0.0
            }
        ]

    Each entry inserts one augmentation model into the pipeline; "prob" is
    the probability that the augmentor takes effect on any given sample. An
    augmentor with "prob" of zero never takes effect.

    :param augmentation_config: Augmentation configuration in json string.
    :type augmentation_config: str
    :param random_seed: Random seed.
    :type random_seed: int
    :raises ValueError: If the augmentation json config is in incorrect format.
    """

    def __init__(self, augmentation_config, random_seed=0):
        self._rng = random.Random(random_seed)
        self._augmentors, self._rates = self._parse_pipeline_from(
            augmentation_config)

    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        # Each augmentor fires independently with its configured probability.
        for augmentor, rate in zip(self._augmentors, self._rates):
            if self._rng.uniform(0., 1.) < rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        """Parse the config json and build the augmentation pipeline."""
        try:
            configs = json.loads(config_json)
            augmentors = []
            rates = []
            for config in configs:
                augmentors.append(
                    self._get_augmentor(config["type"], config["params"]))
                rates.append(config["prob"])
        except Exception as e:
            # Re-raise anything (bad json, missing keys, bad params) as a
            # single ValueError so callers have one failure mode.
            raise ValueError("Failed to parse the augmentation config json: "
                             "%s" % str(e))
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            augmentor_cls = VolumePerturbAugmentor
        elif augmentor_type == "shift":
            augmentor_cls = ShiftPerturbAugmentor
        elif augmentor_type == "speed":
            augmentor_cls = SpeedPerturbAugmentor
        elif augmentor_type == "resample":
            augmentor_cls = ResampleAugmentor
        elif augmentor_type == "bayesian_normal":
            augmentor_cls = OnlineBayesianNormalizationAugmentor
        elif augmentor_type == "noise":
            augmentor_cls = NoisePerturbAugmentor
        elif augmentor_type == "impulse":
            augmentor_cls = ImpulseResponseAugmentor
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
        return augmentor_cls(self._rng, **params)
|
@ -0,0 +1,33 @@
|
||||
"""Contains the abstract base class for augmentation models."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
class AugmentorBase(object):
    """Abstract base class for augmentation model (augmentor) class.

    All augmentor classes should inherit from this class, and implement the
    following abstract methods.
    """

    # Python-2 style ABC declaration: marks the methods decorated with
    # @abstractmethod below as required overrides.
    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Adds various effects to the input audio segment. Such effects
        will augment the training data to make the model invariant to certain
        types of perturbations in the real world, improving model's
        generalization ability.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        pass
|
@ -0,0 +1,34 @@
|
||||
"""Contains the impulse response augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
from data_utils.utility import read_manifest
|
||||
from data_utils.audio import AudioSegment
|
||||
|
||||
|
||||
class ImpulseResponseAugmentor(AugmentorBase):
    """Augmentation model that convolves audio with an impulse response.

    :param rng: Random generator object.
    :type rng: random.Random
    :param impulse_manifest_path: Manifest path for impulse audio data.
    :type impulse_manifest_path: basestring
    """

    def __init__(self, rng, impulse_manifest_path):
        self._rng = rng
        self._impulse_manifest = read_manifest(impulse_manifest_path)

    def transform_audio(self, audio_segment):
        """Add impulse response effect.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        # Draw one random impulse-response entry from the manifest.
        [impulse_json] = self._rng.sample(self._impulse_manifest, 1)
        impulse = AudioSegment.from_file(impulse_json['audio_filepath'])
        audio_segment.convolve(impulse, allow_resample=True)
|
@ -0,0 +1,49 @@
|
||||
"""Contains the noise perturb augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
from data_utils.utility import read_manifest
|
||||
from data_utils.audio import AudioSegment
|
||||
|
||||
|
||||
class NoisePerturbAugmentor(AugmentorBase):
    """Augmentation model that mixes background noise into audio.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_snr_dB: Minimal signal-to-noise ratio, in decibels.
    :type min_snr_dB: float
    :param max_snr_dB: Maximal signal-to-noise ratio, in decibels.
    :type max_snr_dB: float
    :param noise_manifest_path: Manifest path for noise audio data.
    :type noise_manifest_path: basestring
    """

    def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
        self._rng = rng
        self._min_snr_dB = min_snr_dB
        self._max_snr_dB = max_snr_dB
        self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)

    def transform_audio(self, audio_segment):
        """Add background noise audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        [noise_json] = self._rng.sample(self._noise_manifest, 1)
        if noise_json['duration'] < audio_segment.duration:
            raise RuntimeError("The duration of sampled noise audio is smaller "
                               "than the audio segment to add effects to.")
        # Pick a random window of the noise clip matching the audio length.
        spare_duration = noise_json['duration'] - audio_segment.duration
        start = self._rng.uniform(0, spare_duration)
        noise_segment = AudioSegment.slice_from_file(
            noise_json['audio_filepath'],
            start=start,
            end=start + audio_segment.duration)
        snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
        audio_segment.add_noise(
            noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
|
@ -0,0 +1,48 @@
|
||||
"""Contain the online bayesian normalization augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class OnlineBayesianNormalizationAugmentor(AugmentorBase):
    """Augmentation model for adding online Bayesian normalization.

    :param rng: Random generator object.
    :type rng: random.Random
    :param target_db: Target RMS value in decibels.
    :type target_db: float
    :param prior_db: Prior RMS estimate in decibels.
    :type prior_db: float
    :param prior_samples: Prior strength in number of samples.
    :type prior_samples: int
    :param startup_delay: Default 0.0s. If provided, statistics are accrued
                          for the first startup_delay seconds before online
                          normalization is applied.
    :type startup_delay: float
    """

    def __init__(self, rng, target_db, prior_db, prior_samples,
                 startup_delay=0.0):
        self._rng = rng
        self._target_db = target_db
        self._prior_db = prior_db
        self._prior_samples = prior_samples
        self._startup_delay = startup_delay

    def transform_audio(self, audio_segment):
        """Normalizes the input audio using the online Bayesian approach.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.normalize_online_bayesian(
            self._target_db, self._prior_db, self._prior_samples,
            self._startup_delay)
|
@ -0,0 +1,33 @@
|
||||
"""Contain the resample augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class ResampleAugmentor(AugmentorBase):
    """Augmentation model for resampling audio to a new sample rate.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param rng: Random generator object.
    :type rng: random.Random
    :param new_sample_rate: New sample rate in Hz.
    :type new_sample_rate: int
    """

    def __init__(self, rng, new_sample_rate):
        self._rng = rng
        self._new_sample_rate = new_sample_rate

    def transform_audio(self, audio_segment):
        """Resamples the input audio to the configured target sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.resample(self._new_sample_rate)
|
@ -0,0 +1,34 @@
|
||||
"""Contains the volume perturb augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentation model that applies a random time shift.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_shift_ms: Minimal shift in milliseconds.
    :type min_shift_ms: float
    :param max_shift_ms: Maximal shift in milliseconds.
    :type max_shift_ms: float
    """

    def __init__(self, rng, min_shift_ms, max_shift_ms):
        self._rng = rng
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms

    def transform_audio(self, audio_segment):
        """Shift audio by a random offset within the configured range.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        offset_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(offset_ms)
|
@ -0,0 +1,47 @@
|
||||
"""Contain the speech perturbation augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding speed perturbation.

    See reference paper here:
    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_speed_rate: Lower bound of new speed rate to sample and should
                           not be smaller than 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of new speed rate to sample and should
                           not be larger than 1.1.
    :type max_speed_rate: float
    :raises ValueError: If the requested rate range exceeds [0.9, 1.1].
    """

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        # Rates outside [0.9, 1.1] distort speech unnaturally; reject them.
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        self._rng = rng
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate

    def transform_audio(self, audio_segment):
        """Sample a new speed rate from the configured range and change the
        speed of the given audio clip.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        new_rate = self._rng.uniform(self._min_speed_rate,
                                     self._max_speed_rate)
        audio_segment.change_speed(new_rate)
|
@ -0,0 +1,40 @@
|
||||
"""Contains the volume perturb augmentation model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.augmentor.base import AugmentorBase
|
||||
|
||||
|
||||
class VolumePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random volume perturbation.

    This is used for multi-loudness training of PCEN. See

    https://arxiv.org/pdf/1607.05666v1.pdf

    for more details.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_gain_dBFS: Minimal gain in dBFS.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Maximal gain in dBFS.
    :type max_gain_dBFS: float
    """

    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._rng = rng
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS

    def transform_audio(self, audio_segment):
        """Change audio loudness by a random gain within the configured range.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain_dBFS = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain_dBFS)
|
@ -0,0 +1,346 @@
|
||||
"""Contains data generator for orgnaizing various audio data preprocessing
|
||||
pipeline and offering data reader interface of PaddlePaddle requirements.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import random
|
||||
import tarfile
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
from threading import local
|
||||
from data_utils.utility import read_manifest
|
||||
from data_utils.utility import xmap_readers_mp
|
||||
from data_utils.augmentor.augmentation import AugmentationPipeline
|
||||
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
|
||||
from data_utils.speech import SpeechSegment
|
||||
from data_utils.normalizer import FeatureNormalizer
|
||||
|
||||
|
||||
class DataGenerator(object):
|
||||
"""
|
||||
DataGenerator provides basic audio data preprocessing pipeline, and offers
|
||||
data reader interfaces of PaddlePaddle requirements.
|
||||
|
||||
:param vocab_filepath: Vocabulary filepath for indexing tokenized
|
||||
transcripts.
|
||||
:type vocab_filepath: basestring
|
||||
:param mean_std_filepath: File containing the pre-computed mean and stddev.
|
||||
:type mean_std_filepath: None|basestring
|
||||
:param augmentation_config: Augmentation configuration in json string.
|
||||
Details see AugmentationPipeline.__doc__.
|
||||
:type augmentation_config: str
|
||||
:param max_duration: Audio with duration (in seconds) greater than
|
||||
this will be discarded.
|
||||
:type max_duration: float
|
||||
:param min_duration: Audio with duration (in seconds) smaller than
|
||||
this will be discarded.
|
||||
:type min_duration: float
|
||||
:param stride_ms: Striding size (in milliseconds) for generating frames.
|
||||
:type stride_ms: float
|
||||
:param window_ms: Window size (in milliseconds) for generating frames.
|
||||
:type window_ms: float
|
||||
:param max_freq: Used when specgram_type is 'linear', only FFT bins
|
||||
corresponding to frequencies between [0, max_freq] are
|
||||
returned.
|
||||
:types max_freq: None|float
|
||||
:param specgram_type: Specgram feature type. Options: 'linear'.
|
||||
:type specgram_type: str
|
||||
:param use_dB_normalization: Whether to normalize the audio to -20 dB
|
||||
before extracting the features.
|
||||
:type use_dB_normalization: bool
|
||||
:param num_threads: Number of CPU threads for processing data.
|
||||
:type num_threads: int
|
||||
:param random_seed: Random seed.
|
||||
:type random_seed: int
|
||||
:param keep_transcription_text: If set to True, transcription text will
|
||||
be passed forward directly without
|
||||
converting to index sequence.
|
||||
:type keep_transcription_text: bool
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
vocab_filepath,
|
||||
mean_std_filepath,
|
||||
augmentation_config='{}',
|
||||
max_duration=float('inf'),
|
||||
min_duration=0.0,
|
||||
stride_ms=10.0,
|
||||
window_ms=20.0,
|
||||
max_freq=None,
|
||||
specgram_type='linear',
|
||||
use_dB_normalization=True,
|
||||
num_threads=multiprocessing.cpu_count() // 2,
|
||||
random_seed=0,
|
||||
keep_transcription_text=False):
|
||||
self._max_duration = max_duration
|
||||
self._min_duration = min_duration
|
||||
self._normalizer = FeatureNormalizer(mean_std_filepath)
|
||||
self._augmentation_pipeline = AugmentationPipeline(
|
||||
augmentation_config=augmentation_config, random_seed=random_seed)
|
||||
self._speech_featurizer = SpeechFeaturizer(
|
||||
vocab_filepath=vocab_filepath,
|
||||
specgram_type=specgram_type,
|
||||
stride_ms=stride_ms,
|
||||
window_ms=window_ms,
|
||||
max_freq=max_freq,
|
||||
use_dB_normalization=use_dB_normalization)
|
||||
self._num_threads = num_threads
|
||||
self._rng = random.Random(random_seed)
|
||||
self._keep_transcription_text = keep_transcription_text
|
||||
self._epoch = 0
|
||||
# for caching tar files info
|
||||
self._local_data = local()
|
||||
self._local_data.tar2info = {}
|
||||
self._local_data.tar2object = {}
|
||||
|
||||
def process_utterance(self, audio_file, transcript):
|
||||
"""Load, augment, featurize and normalize for speech data.
|
||||
|
||||
:param audio_file: Filepath or file object of audio file.
|
||||
:type audio_file: basestring | file
|
||||
:param transcript: Transcription text.
|
||||
:type transcript: basestring
|
||||
:return: Tuple of audio feature tensor and data of transcription part,
|
||||
where transcription part could be token ids or text.
|
||||
:rtype: tuple of (2darray, list)
|
||||
"""
|
||||
if isinstance(audio_file, basestring) and audio_file.startswith('tar:'):
|
||||
speech_segment = SpeechSegment.from_file(
|
||||
self._subfile_from_tar(audio_file), transcript)
|
||||
else:
|
||||
speech_segment = SpeechSegment.from_file(audio_file, transcript)
|
||||
self._augmentation_pipeline.transform_audio(speech_segment)
|
||||
specgram, transcript_part = self._speech_featurizer.featurize(
|
||||
speech_segment, self._keep_transcription_text)
|
||||
specgram = self._normalizer.apply(specgram)
|
||||
return specgram, transcript_part
|
||||
|
||||
def batch_reader_creator(self,
|
||||
manifest_path,
|
||||
batch_size,
|
||||
min_batch_size=1,
|
||||
padding_to=-1,
|
||||
flatten=False,
|
||||
sortagrad=False,
|
||||
shuffle_method="batch_shuffle"):
|
||||
"""
|
||||
Batch data reader creator for audio data. Return a callable generator
|
||||
function to produce batches of data.
|
||||
|
||||
Audio features within one batch will be padded with zeros to have the
|
||||
same shape, or a user-defined shape.
|
||||
|
||||
:param manifest_path: Filepath of manifest for audio files.
|
||||
:type manifest_path: basestring
|
||||
:param batch_size: Number of instances in a batch.
|
||||
:type batch_size: int
|
||||
:param min_batch_size: Any batch with batch size smaller than this will
|
||||
be discarded. (To be deprecated in the future.)
|
||||
:type min_batch_size: int
|
||||
:param padding_to: If set -1, the maximun shape in the batch
|
||||
will be used as the target shape for padding.
|
||||
Otherwise, `padding_to` will be the target shape.
|
||||
:type padding_to: int
|
||||
:param flatten: If set True, audio features will be flatten to 1darray.
|
||||
:type flatten: bool
|
||||
:param sortagrad: If set True, sort the instances by audio duration
|
||||
in the first epoch for speed up training.
|
||||
:type sortagrad: bool
|
||||
:param shuffle_method: Shuffle method. Options:
|
||||
'' or None: no shuffle.
|
||||
'instance_shuffle': instance-wise shuffle.
|
||||
'batch_shuffle': similarly-sized instances are
|
||||
put into batches, and then
|
||||
batch-wise shuffle the batches.
|
||||
For more details, please see
|
||||
``_batch_shuffle.__doc__``.
|
||||
'batch_shuffle_clipped': 'batch_shuffle' with
|
||||
head shift and tail
|
||||
clipping. For more
|
||||
details, please see
|
||||
``_batch_shuffle``.
|
||||
If sortagrad is True, shuffle is disabled
|
||||
for the first epoch.
|
||||
:type shuffle_method: None|str
|
||||
:return: Batch reader function, producing batches of data when called.
|
||||
:rtype: callable
|
||||
"""
|
||||
|
||||
def batch_reader():
|
||||
# read manifest
|
||||
manifest = read_manifest(
|
||||
manifest_path=manifest_path,
|
||||
max_duration=self._max_duration,
|
||||
min_duration=self._min_duration)
|
||||
# sort (by duration) or batch-wise shuffle the manifest
|
||||
if self._epoch == 0 and sortagrad:
|
||||
manifest.sort(key=lambda x: x["duration"])
|
||||
else:
|
||||
if shuffle_method == "batch_shuffle":
|
||||
manifest = self._batch_shuffle(
|
||||
manifest, batch_size, clipped=False)
|
||||
elif shuffle_method == "batch_shuffle_clipped":
|
||||
manifest = self._batch_shuffle(
|
||||
manifest, batch_size, clipped=True)
|
||||
elif shuffle_method == "instance_shuffle":
|
||||
self._rng.shuffle(manifest)
|
||||
elif shuffle_method == None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unknown shuffle method %s." %
|
||||
shuffle_method)
|
||||
# prepare batches
|
||||
instance_reader, cleanup = self._instance_reader_creator(manifest)
|
||||
batch = []
|
||||
try:
|
||||
for instance in instance_reader():
|
||||
batch.append(instance)
|
||||
if len(batch) == batch_size:
|
||||
yield self._padding_batch(batch, padding_to, flatten)
|
||||
batch = []
|
||||
if len(batch) >= min_batch_size:
|
||||
yield self._padding_batch(batch, padding_to, flatten)
|
||||
finally:
|
||||
cleanup()
|
||||
self._epoch += 1
|
||||
|
||||
return batch_reader
|
||||
|
||||
@property
|
||||
def feeding(self):
|
||||
"""Returns data reader's feeding dict.
|
||||
|
||||
:return: Data feeding dict.
|
||||
:rtype: dict
|
||||
"""
|
||||
feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
|
||||
return feeding_dict
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Return the vocabulary size.
|
||||
|
||||
:return: Vocabulary size.
|
||||
:rtype: int
|
||||
"""
|
||||
return self._speech_featurizer.vocab_size
|
||||
|
||||
@property
|
||||
def vocab_list(self):
|
||||
"""Return the vocabulary in list.
|
||||
|
||||
:return: Vocabulary in list.
|
||||
:rtype: list
|
||||
"""
|
||||
return self._speech_featurizer.vocab_list
|
||||
|
||||
def _parse_tar(self, file):
|
||||
"""Parse a tar file to get a tarfile object
|
||||
and a map containing tarinfoes
|
||||
"""
|
||||
result = {}
|
||||
f = tarfile.open(file)
|
||||
for tarinfo in f.getmembers():
|
||||
result[tarinfo.name] = tarinfo
|
||||
return f, result
|
||||
|
||||
def _subfile_from_tar(self, file):
|
||||
"""Get subfile object from tar.
|
||||
|
||||
It will return a subfile object from tar file
|
||||
and cached tar file info for next reading request.
|
||||
"""
|
||||
tarpath, filename = file.split(':', 1)[1].split('#', 1)
|
||||
if 'tar2info' not in self._local_data.__dict__:
|
||||
self._local_data.tar2info = {}
|
||||
if 'tar2object' not in self._local_data.__dict__:
|
||||
self._local_data.tar2object = {}
|
||||
if tarpath not in self._local_data.tar2info:
|
||||
object, infoes = self._parse_tar(tarpath)
|
||||
self._local_data.tar2info[tarpath] = infoes
|
||||
self._local_data.tar2object[tarpath] = object
|
||||
return self._local_data.tar2object[tarpath].extractfile(
|
||||
self._local_data.tar2info[tarpath][filename])
|
||||
|
||||
def _instance_reader_creator(self, manifest):
|
||||
"""
|
||||
Instance reader creator. Create a callable function to produce
|
||||
instances of data.
|
||||
|
||||
Instance: a tuple of ndarray of audio spectrogram and a list of
|
||||
token indices for transcript.
|
||||
"""
|
||||
|
||||
def reader():
|
||||
for instance in manifest:
|
||||
yield instance
|
||||
|
||||
reader, cleanup_callback = xmap_readers_mp(
|
||||
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
|
||||
reader, self._num_threads, 4096)
|
||||
|
||||
return reader, cleanup_callback
|
||||
|
||||
def _padding_batch(self, batch, padding_to=-1, flatten=False):
|
||||
"""
|
||||
Padding audio features with zeros to make them have the same shape (or
|
||||
a user-defined shape) within one bach.
|
||||
|
||||
If ``padding_to`` is -1, the maximun shape in the batch will be used
|
||||
as the target shape for padding. Otherwise, `padding_to` will be the
|
||||
target shape (only refers to the second axis).
|
||||
|
||||
If `flatten` is True, features will be flatten to 1darray.
|
||||
"""
|
||||
new_batch = []
|
||||
# get target shape
|
||||
max_length = max([audio.shape[1] for audio, text in batch])
|
||||
if padding_to != -1:
|
||||
if padding_to < max_length:
|
||||
raise ValueError("If padding_to is not -1, it should be larger "
|
||||
"than any instance's shape in the batch")
|
||||
max_length = padding_to
|
||||
# padding
|
||||
for audio, text in batch:
|
||||
padded_audio = np.zeros([audio.shape[0], max_length])
|
||||
padded_audio[:, :audio.shape[1]] = audio
|
||||
if flatten:
|
||||
padded_audio = padded_audio.flatten()
|
||||
padded_instance = [padded_audio, text, audio.shape[1]]
|
||||
new_batch.append(padded_instance)
|
||||
return new_batch
|
||||
|
||||
def _batch_shuffle(self, manifest, batch_size, clipped=False):
|
||||
"""Put similarly-sized instances into minibatches for better efficiency
|
||||
and make a batch-wise shuffle.
|
||||
|
||||
1. Sort the audio clips by duration.
|
||||
2. Generate a random number `k`, k in [0, batch_size).
|
||||
3. Randomly shift `k` instances in order to create different batches
|
||||
for different epochs. Create minibatches.
|
||||
4. Shuffle the minibatches.
|
||||
|
||||
:param manifest: Manifest contents. List of dict.
|
||||
:type manifest: list
|
||||
:param batch_size: Batch size. This size is also used for generate
|
||||
a random number for batch shuffle.
|
||||
:type batch_size: int
|
||||
:param clipped: Whether to clip the heading (small shift) and trailing
|
||||
(incomplete batch) instances.
|
||||
:type clipped: bool
|
||||
:return: Batch shuffled mainifest.
|
||||
:rtype: list
|
||||
"""
|
||||
manifest.sort(key=lambda x: x["duration"])
|
||||
shift_len = self._rng.randint(0, batch_size - 1)
|
||||
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
|
||||
self._rng.shuffle(batch_manifest)
|
||||
batch_manifest = [item for batch in batch_manifest for item in batch]
|
||||
if not clipped:
|
||||
res_len = len(manifest) - shift_len - len(batch_manifest)
|
||||
batch_manifest.extend(manifest[-res_len:])
|
||||
batch_manifest.extend(manifest[0:shift_len])
|
||||
return batch_manifest
|
@ -0,0 +1,187 @@
|
||||
"""Contains the audio featurizer class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
from data_utils.utility import read_manifest
|
||||
from data_utils.audio import AudioSegment
|
||||
from python_speech_features import mfcc
|
||||
from python_speech_features import delta
|
||||
|
||||
|
||||
class AudioFeaturizer(object):
    """Audio featurizer, for extracting features from audio contents of
    AudioSegment or SpeechSegment.

    Currently, it supports feature types of linear spectrogram and mfcc.

    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Audio are resampled (if upsampling or
                               downsampling is allowed) to this before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibels before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
        self._target_sample_rate = target_sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB

    def featurize(self,
                  audio_segment,
                  allow_downsampling=True,
                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
        :param allow_downsampling: Whether to allow audio downsampling before
                                   featurizing.
        :type allow_downsampling: bool
        :param allow_upsampling: Whether to allow audio upsampling before
                                 featurizing.
        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        :raises ValueError: If audio sample rate is not supported.
        """
        # upsampling or downsampling (mutates audio_segment in place)
        if ((audio_segment.sample_rate > self._target_sample_rate and
             allow_downsampling) or
            (audio_segment.sample_rate < self._target_sample_rate and
             allow_upsampling)):
            audio_segment.resample(self._target_sample_rate)
        if audio_segment.sample_rate != self._target_sample_rate:
            raise ValueError("Audio sample rate is not supported. "
                             "Turn allow_downsampling or allow up_sampling on.")
        # decibel normalization (also mutates audio_segment in place)
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # extract spectrogram
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)

    def _compute_specgram(self, samples, sample_rate):
        """Extract various audio features."""
        if self._specgram_type == 'linear':
            return self._compute_linear_specgram(
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        elif self._specgram_type == 'mfcc':
            return self._compute_mfcc(samples, sample_rate, self._stride_ms,
                                      self._window_ms, self._max_freq)
        else:
            # Fixed message: 'mfcc' is handled above but the original error
            # only advertised 'linear'.
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc." %
                             self._specgram_type)

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Compute the linear spectrogram from FFT energy.

        :return: Log spectrogram, shape (num_freq_bins, num_frames), where
                 only bins with frequency <= max_freq are kept.
        :rtype: ndarray
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        # eps avoids log(0) on silent frames
        return np.log(specgram[:ind, :] + eps)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram for samples from a real signal."""
        # extract strided windows (zero-copy view over the sample buffer)
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)
        fft = fft**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        # DC and Nyquist bins appear once in the rfft, hence no doubling
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

    def _compute_mfcc(self,
                      samples,
                      sample_rate,
                      stride_ms=10.0,
                      window_ms=20.0,
                      max_freq=None):
        """Compute mfcc from samples.

        :return: Stacked mfcc + delta + delta-delta features, transposed to
                 (num_features, num_frames).
        :rtype: ndarray
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # compute the 13 cepstral coefficients, and the first one is replaced
        # by log(frame energy)
        mfcc_feat = mfcc(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            highfreq=max_freq)
        # Deltas
        d_mfcc_feat = delta(mfcc_feat, 2)
        # Deltas-Deltas
        dd_mfcc_feat = delta(d_mfcc_feat, 2)
        # transpose
        mfcc_feat = np.transpose(mfcc_feat)
        d_mfcc_feat = np.transpose(d_mfcc_feat)
        dd_mfcc_feat = np.transpose(dd_mfcc_feat)
        # concat above three features
        concat_mfcc_feat = np.concatenate(
            (mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
        return concat_mfcc_feat
|
@ -0,0 +1,98 @@
|
||||
"""Contains the speech featurizer class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
|
||||
from data_utils.featurizer.text_featurizer import TextFeaturizer
|
||||
|
||||
|
||||
class SpeechFeaturizer(object):
    """Speech featurizer, extracting features from both the audio and the
    transcript content of a SpeechSegment.

    The audio part is turned into a linear spectrogram or mfcc features; the
    transcript part is either kept as raw text or tokenized char-by-char into
    indices following the given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: basestring
    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Speech are resampled (if upsampling or
                               downsampling is allowed) to this before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibels before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 vocab_filepath,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._text_featurizer = TextFeaturizer(vocab_filepath)
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB)

    def featurize(self, speech_segment, keep_transcription_text):
        """Extract features for a speech segment.

        1. For the audio part, extract the audio features.
        2. For the transcript part, keep the original text or convert it to
           a list of char-level token indices.

        :param speech_segment: Speech segment to extract features from.
        :type speech_segment: SpeechSegment
        :param keep_transcription_text: Whether to return the raw transcript
                                        text instead of token indices.
        :type keep_transcription_text: bool
        :return: Tuple of 1) spectrogram audio feature in 2darray, 2) either
                 the transcript text or its char-level token indices.
        :rtype: tuple
        """
        spectrogram = self._audio_featurizer.featurize(speech_segment)
        transcript = speech_segment.transcript
        if not keep_transcription_text:
            transcript = self._text_featurizer.featurize(transcript)
        return spectrogram, transcript

    @property
    def vocab_size(self):
        """Vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Vocabulary tokens.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._text_featurizer.vocab_list
|
@ -0,0 +1,68 @@
|
||||
"""Contains the text featurizer class."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import codecs
|
||||
|
||||
|
||||
class TextFeaturizer(object):
    """Text featurizer, for processing or extracting features from text.

    Currently, it only supports char-level tokenizing and conversion into
    a list of token indices. Note that the token indexing order follows the
    given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: basestring
    """

    def __init__(self, vocab_filepath):
        self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
            vocab_filepath)

    def featurize(self, text):
        """Convert text string to a list of token indices in char-level. Note
        that the token indexing order follows the given vocabulary file.

        :param text: Text to process.
        :type text: basestring
        :return: List of char-level token indices.
        :rtype: list
        :raises KeyError: If a character is not present in the vocabulary.
        """
        tokens = self._char_tokenize(text)
        return [self._vocab_dict[token] for token in tokens]

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self._vocab_list)

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._vocab_list

    def _char_tokenize(self, text):
        """Character tokenizer: strip surrounding whitespace, then split
        into individual characters."""
        return list(text.strip())

    def _load_vocabulary_from_file(self, vocab_filepath):
        """Load vocabulary from file.

        :return: Tuple of (token -> index dict, token list).
        :rtype: tuple
        """
        vocab_list = []
        with codecs.open(vocab_filepath, 'r', 'utf-8') as vocab_file:
            for line in vocab_file:
                # Strip only the newline: line[:-1] (the previous approach)
                # chops the last character of a file without a trailing
                # newline, and whitespace tokens (e.g. the space char) must
                # survive intact.
                vocab_list.append(line.rstrip('\n'))
        vocab_dict = dict(
            (token, idx) for idx, token in enumerate(vocab_list))
        return vocab_dict, vocab_list
|
@ -0,0 +1,87 @@
|
||||
"""Contains feature normalizers."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
from data_utils.utility import read_manifest
|
||||
from data_utils.audio import AudioSegment
|
||||
|
||||
|
||||
class FeatureNormalizer(object):
    """Feature normalizer. Normalize features to be of zero mean and unit
    stddev.

    If mean_std_filepath is provided (not None), the normalizer will directly
    initialize from the file. Otherwise, both manifest_path and featurize_func
    should be given for on-the-fly mean and stddev computing.

    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|basestring
    :param manifest_path: Manifest of instances for computing mean and stddev.
    :type manifest_path: None|basestring
    :param featurize_func: Function to extract features. It should be callable
                           with ``featurize_func(audio_segment)``.
    :type featurize_func: None|callable
    :param num_samples: Number of random samples for computing mean and stddev.
    :type num_samples: int
    :param random_seed: Random seed for sampling instances.
    :type random_seed: int
    :raises ValueError: If both mean_std_filepath and manifest_path
                        (or both mean_std_filepath and featurize_func) are None.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, meanifest_path "
                                 "and featurize_func should not be None.")
            self._rng = random.Random(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features, eps=1e-14):
        """Normalize features to be of zero mean and unit stddev.

        :param features: Input features to be normalized.
        :type features: ndarray
        :param eps: Added to stddev to provide numerical stability.
        :type eps: float
        :return: Normalized features.
        :rtype: ndarray
        """
        return (features - self._mean) / (self._std + eps)

    def write_to_file(self, filepath):
        """Write the mean and stddev to the file.

        :param filepath: File to write mean and stddev.
        :type filepath: basestring
        """
        np.savez(filepath, mean=self._mean, std=self._std)

    def _read_mean_std_from_file(self, filepath):
        """Load mean and std from a .npz file produced by write_to_file."""
        npzfile = np.load(filepath)
        self._mean = npzfile["mean"]
        self._std = npzfile["std"]

    def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
        """Compute mean and std from randomly sampled instances."""
        manifest = read_manifest(manifest_path)
        # NOTE(review): random.Random.sample raises ValueError when the
        # manifest holds fewer than num_samples entries -- confirm callers
        # always provide a large enough manifest.
        sampled_manifest = self._rng.sample(manifest, num_samples)
        features = []
        for instance in sampled_manifest:
            features.append(
                featurize_func(
                    AudioSegment.from_file(instance["audio_filepath"])))
        # hstack then reduce along axis=1: statistics are per feature row,
        # pooled over all frames of all sampled instances
        features = np.hstack(features)
        self._mean = np.mean(features, axis=1).reshape([-1, 1])
        self._std = np.std(features, axis=1).reshape([-1, 1])
|
@ -0,0 +1,143 @@
|
||||
"""Contains the speech segment class."""
|
||||
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from data_utils.audio import AudioSegment
|
||||
|
||||
|
||||
class SpeechSegment(AudioSegment):
    """Speech segment abstraction, a subclass of AudioSegment,
    with an additional transcript.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :param transcript: Transcript text for the speech.
    :type transcript: basestring
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    def __eq__(self, other):
        """Return whether two objects are equal.

        Equal means both the audio content (as judged by AudioSegment) and
        the transcript match.
        """
        if not AudioSegment.__eq__(self, other):
            return False
        if self._transcript != other._transcript:
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    @classmethod
    def from_file(cls, filepath, transcript):
        """Create speech segment from audio file and corresponding transcript.

        :param filepath: Filepath or file object to audio file.
        :type filepath: basestring|file
        :param transcript: Transcript text for the speech.
        :type transcript: basestring
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def from_bytes(cls, bytes, transcript):
        """Create speech segment from a byte string and corresponding
        transcript.

        :param bytes: Byte string containing audio samples.
        :type bytes: str
        :param transcript: Transcript text for the speech.
        :type transcript: basestring
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of speech segments together, both
        audio and transcript will be concatenated.

        :param *segments: Input speech segments to be concatenated.
        :type *segments: tuple of SpeechSegment
        :return: Speech segment instance.
        :rtype: SpeechSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any two segments does not match.
        :raises TypeError: If any segment is not SpeechSegment instance.
        """
        if len(segments) == 0:
            raise ValueError("No speech segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        transcripts = ""
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only speech segments of the same type "
                                "instance can be concatenated.")
            transcripts += seg._transcript
        # NOTE(review): `np` is not imported by this module's visible import
        # block -- this line raises NameError unless `import numpy as np` is
        # added at module level; confirm the import exists.
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate, transcripts)

    @classmethod
    def slice_from_file(cls, filepath, transcript, start=None, end=None):
        """Loads a small section of an speech without having to load
        the entire file into the memory which can be incredibly wasteful.

        :param filepath: Filepath or file object to audio file.
        :type filepath: basestring|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :param transcript: Transcript text for the speech. if not provided,
                           the defaults is an empty string.
        :type transcript: basestring
        :return: SpeechSegment instance of the specified slice of the input
                 speech file.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.slice_from_file(filepath, start, end)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent speech segment of the given duration and
        sample rate, transcript will be an empty string.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silence of the given duration.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.make_silence(duration, sample_rate)
        return cls(audio.samples, audio.sample_rate, "")

    @property
    def transcript(self):
        """Return the transcript text.

        :return: Transcript text for the speech.
        :rtype: basestring
        """
        return self._transcript
|
@ -0,0 +1,214 @@
|
||||
"""Contains data helper functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import codecs
|
||||
import os
|
||||
import tarfile
|
||||
import time
|
||||
from Queue import Queue
|
||||
from threading import Thread
|
||||
from multiprocessing import Process, Manager, Value
|
||||
from paddle.v2.dataset.common import md5file
|
||||
|
||||
|
||||
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Instances with durations outside [min_duration, max_duration] will be
    filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    # Context manager ensures the file handle is closed even when a line
    # fails to parse (the original iterated an unclosed handle).
    with codecs.open(manifest_path, 'r', 'utf-8') as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest
|
||||
|
||||
|
||||
def getfile_insensitive(path):
    """Get the actual file path when given insensitive filename.

    Returns None implicitly when no entry in the directory matches the
    filename case-insensitively.
    """
    directory, filename = os.path.split(path)
    if not directory:
        directory = '.'
    target = filename.lower()
    for entry in os.listdir(directory):
        candidate = os.path.join(directory, entry)
        if entry.lower() == target and os.path.isfile(candidate):
            return candidate
|
||||
|
||||
|
||||
def download_multi(url, target_dir, extra_args):
    """Download multiple files from url to target_dir via wget.

    :param url: URL to download from.
    :param target_dir: Directory the files are saved into (created if absent).
    :param extra_args: Extra command-line arguments passed verbatim to wget.
    :return: wget process exit status as returned by os.system.
    """
    # NOTE(review): the command is built by string concatenation and executed
    # through a shell; url and extra_args must come from trusted input
    # (shell-injection risk otherwise).
    if not os.path.exists(target_dir): os.makedirs(target_dir)
    print("Downloading %s ..." % url)
    ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
                         target_dir)
    return ret_code
|
||||
|
||||
|
||||
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    The download is skipped when the target file already exists with a
    matching md5.

    :param url: URL to download from; the basename becomes the local name.
    :param md5sum: Expected md5 hex digest of the downloaded file.
    :param target_dir: Directory the file is saved into (created if absent).
    :return: Local path of the (downloaded or pre-existing) file.
    :raises RuntimeError: If the md5 of the downloaded file does not match.
    """
    if not os.path.exists(target_dir): os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        # NOTE(review): shell command built by concatenation -- url must be
        # trusted (shell-injection risk). Also "Chesksum" below is a typo in
        # a user-visible message; left unchanged here as it is runtime text.
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Chesksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath
|
||||
|
||||
|
||||
def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the tar file to the target_dir.

    :param filepath: Path of the tar archive to unpack.
    :param target_dir: Directory to extract into.
    :param rm_tar: Whether to delete the archive after extraction.
    """
    print("Unpacking %s ..." % filepath)
    tar = tarfile.open(filepath)
    try:
        tar.extractall(target_dir)
    finally:
        # close even when extraction fails (the original leaked the handle)
        tar.close()
    if rm_tar:
        os.remove(filepath)
|
||||
|
||||
|
||||
class XmapEndSignal():
    """Sentinel object placed on a queue to signal that no more samples
    will follow."""
|
||||
|
||||
|
||||
def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
    """A multiprocessing pipeline wrapper for the data reader.

    Pipeline: one reader process -> process_num mapper processes -> one
    flushing thread that copies results from the slow Manager.Queue into a
    fast local Queue consumed by the returned generator.

    :param mapper: Function to map sample.
    :type mapper: callable
    :param reader: Given data reader.
    :type reader: callable
    :param process_num: Number of processes in the pipeline
    :type process_num: int
    :param buffer_size: Maximal buffer size.
    :type buffer_size: int
    :param order: NOTE(review): this flag is never read in the visible body
                  -- only the ordered workers are ever started; confirm
                  whether an unordered path was intended.
    :type order: bool
    :return: The wrappered reader and cleanup callback
    :rtype: tuple
    """
    end_flag = XmapEndSignal()

    read_workers = []
    handle_workers = []
    flush_workers = []

    # shared integer flags used to request shutdown of each pipeline stage
    read_exit_flag = Value('i', 0)
    handle_exit_flag = Value('i', 0)
    flush_exit_flag = Value('i', 0)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        for order_id, sample in enumerate(reader()):
            if read_exit_flag.value == 1: break
            in_queue.put((order_id, sample))
        in_queue.put(end_flag)
        # the reading worker should not exit until all handling work exited
        while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
            time.sleep(0.001)

    # define a worker to handle samples from in_queue by mapper and put results
    # to out_queue with order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            if handle_exit_flag.value == 1: break
            order_id, sample = ins
            result = mapper(sample)
            # busy-wait until it is this sample's turn, preserving order
            while order_id != out_order[0]:
                time.sleep(0.001)
            out_queue.put(result)
            out_order[0] += 1
            ins = in_queue.get()
        # re-put the end flag so sibling handle workers also see it
        in_queue.put(end_flag)
        out_queue.put(end_flag)
        # wait for exit of flushing worker
        while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
            time.sleep(0.001)
        read_exit_flag.value = 1
        handle_exit_flag.value = 1

    # define a thread worker to flush samples from Manager.Queue to Queue
    # for acceleration
    def flush_worker(in_queue, out_queue):
        finish = 0
        # one end flag is expected per handle worker
        while finish < process_num and flush_exit_flag.value == 0:
            sample = in_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                out_queue.put(sample)
        out_queue.put(end_flag)
        handle_exit_flag.value = 1
        flush_exit_flag.value = 1

    def cleanup():
        """Tear the pipeline down stage by stage, downstream first."""
        # first exit flushing workers
        flush_exit_flag.value = 1
        for w in flush_workers:
            w.join()
        # next exit handling workers
        handle_exit_flag.value = 1
        for w in handle_workers:
            w.join()
        # last exit reading workers
        read_exit_flag.value = 1
        for w in read_workers:
            w.join()

    def xreader():
        # prepare shared memory
        manager = Manager()
        in_queue = manager.Queue(buffer_size)
        out_queue = manager.Queue(buffer_size)
        # single-element shared list acting as the "next order id" counter
        out_order = manager.list([0])

        # start a read worker in a process
        target = order_read_worker
        p = Process(target=target, args=(reader, in_queue))
        p.daemon = True
        p.start()
        read_workers.append(p)

        # start handle_workers with multiple processes
        target = order_handle_worker
        args = (in_queue, out_queue, mapper, out_order)
        workers = [
            Process(target=target, args=args) for _ in xrange(process_num)
        ]
        for w in workers:
            w.daemon = True
            w.start()
            handle_workers.append(w)

        # start a thread to read data from slow Manager.Queue
        flush_queue = Queue(buffer_size)
        t = Thread(target=flush_worker, args=(out_queue, flush_queue))
        t.daemon = True
        t.start()
        flush_workers.append(t)

        # get results
        sample = flush_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = flush_queue.get()

    return xreader, cleanup
|
@ -0,0 +1,238 @@
|
||||
"""Contains various CTC decoders."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from itertools import groupby
|
||||
import numpy as np
|
||||
from math import log
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    Path consisting of the most probable tokens are further post-processed to
    remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    :raises ValueError: If a probability row length does not equal the
                        vocabulary size plus one (for the blank).
    """
    # dimension verification: every row must cover vocabulary + blank
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes (blank is the last index by CTC convention)
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC Beam search decoder.

    It utilizes beam search to approximately select top best decoding
    labels and returning results in the descending order.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), and the unclear part is
    redesigned. Two important modifications: 1) in the iterative computation
    of probabilities, the assignment operation is changed to accumulation for
    one prefix may comes from different paths; 2) the if condition "if l^+ not
    in A_prev then" after probabilities' computation is deprecated for it is
    hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning: only the top cutoff_top_n
                         characters with the highest probabilities are kept
                         per time step, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder used in multiprocesses.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id assign (blank is the extra, last label)
    blank_id = len(vocabulary)

    # If the decoder called in the multiprocesses, then use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    ## initialize
    # '\t' is a sentinel marking the sentence start, so every prefix is
    # non-empty and l[-1] is always defined.
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for time_step in range(len(probs_seq)):
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(probs_seq[time_step]))
        cutoff_len = len(prob_idx)
        # If pruning is enabled
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for i in range(len(prob_idx)):
                cum_prob += prob_idx[i][1]
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
            prob_idx = prob_idx[0:cutoff_len]

        for l in prefix_set_prev:
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for index in range(cutoff_len):
                c, prob_c = prob_idx[index][0], prob_idx[index][1]

                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        # repeated char: extending requires an intervening
                        # blank, staying on l requires a non-blank ending
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                    elif new_char == ' ':
                        # word boundary: apply the external scorer (if any)
                        # to the completed prefix
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            log_prob = log(prob)
            beam_result.append((log_prob, result))
        else:
            beam_result.append((float('-inf'), ''))

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result
|
||||
|
||||
|
||||
def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning: only the top cutoff_top_n
                         characters with the highest probabilities are kept
                         per time step, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :raises ValueError: If num_processes is not positive.
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # use a global variable to pass the external scorer to the beam search
    # decoder in the worker processes (the scorer itself may not be picklable,
    # so it is not passed through the task arguments below)
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for i, probs_list in enumerate(probs_split):
        # ext_scoring_func is passed as None; workers pick up the global
        # scorer because nproc=True
        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
                None, nproc)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = [result.get() for result in results]
    return beam_search_results
|
@ -0,0 +1,68 @@
|
||||
"""External Scorer for Beam Search Decoder."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import kenlm
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Scorer(object):
    """External scorer to evaluate a prefix or whole sentence in
    beam search decoding, including the score from n-gram language
    model and word count.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :raises IOError: If model_path does not point to an existing file.
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def _language_model_score(self, sentence):
        # log10 conditional prob of the last word given its context
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def _word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # reset alpha and beta
    def reset_params(self, alpha, beta):
        self._alpha = alpha
        self._beta = beta

    # execute evaluation
    def __call__(self, sentence, log=False):
        """Evaluation function, gathering all the different scores
        and return the final one.

        :param sentence: The input sentence for evaluation
        :type sentence: basestring
        :param log: Whether return the score in log representation.
        :type log: bool
        :return: Evaluation score, in the decimal or log.
        :rtype: float
        """
        lm = self._language_model_score(sentence)
        word_cnt = self._word_count(sentence)
        if not log:
            score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        else:
            score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        return score
|
@ -0,0 +1,19 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Prepend *path* to ``sys.path`` unless it is already present."""
    if path in sys.path:
        return
    sys.path.insert(0, path)
|
||||
|
||||
|
||||
# Directory containing this file ('' when imported from the current directory).
this_dir = os.path.dirname(__file__)

# Add project path to PYTHONPATH so sibling packages resolve when this
# module is imported from scripts in other directories.
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
|
@ -0,0 +1,222 @@
|
||||
#include "ctc_beam_search_decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
|
||||
#include "ThreadPool.h"
|
||||
#include "fst/fstlib.h"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
#include "path_trie.h"
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
|
||||
|
||||
// CTC prefix beam search over a single utterance.  Maintains, per prefix,
// the log probabilities of ending in blank (log_prob_b_*) and non-blank
// (log_prob_nb_*) inside a PathTrie, and optionally applies an external
// scorer (n-gram LM + word insertion) at word boundaries.
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer) {
  // dimension check: each time step covers the vocabulary plus the blank
  size_t num_time_steps = probs_seq.size();
  for (size_t i = 0; i < num_time_steps; ++i) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
                   "the shape of the vocabulary");
  }

  // assign blank id (the extra, last label)
  size_t blank_id = vocabulary.size();

  // assign space id
  auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
  int space_id = it - vocabulary.begin();
  // if no space in vocabulary, use a sentinel that matches no label
  if ((size_t)space_id >= vocabulary.size()) {
    space_id = -2;
  }

  // init prefixes' root (the empty prefix, probability 1 in log space)
  PathTrie root;
  root.score = root.log_prob_b_prev = 0.0;
  std::vector<PathTrie *> prefixes;
  prefixes.push_back(&root);

  // A word-based scorer carries a dictionary FST; attach a copy and a
  // matcher to the trie root so prefixes can be constrained to the lexicon.
  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
    auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
    fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
    root.set_dictionary(dict_ptr);
    auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
    root.set_matcher(matcher);
  }

  // prefix search over time
  for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
    auto &prob = probs_seq[time_step];

    // With a full beam and a scorer, candidates scoring below the current
    // worst beam entry (min_cutoff) cannot survive and are skipped early.
    float min_cutoff = -NUM_FLT_INF;
    bool full_beam = false;
    if (ext_scorer != nullptr) {
      size_t num_prefixes = std::min(prefixes.size(), beam_size);
      std::sort(
          prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
      min_cutoff = prefixes[num_prefixes - 1]->score +
                   std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
      full_beam = (num_prefixes == beam_size);
    }

    std::vector<std::pair<size_t, float>> log_prob_idx =
        get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
    // loop over chars
    for (size_t index = 0; index < log_prob_idx.size(); index++) {
      auto c = log_prob_idx[index].first;
      auto log_prob_c = log_prob_idx[index].second;

      for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
        auto prefix = prefixes[i];
        // prefixes are sorted, so every later one is below the cutoff too
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
          break;
        }
        // blank: prefix string unchanged, accumulate into the blank-ending
        // probability
        if (c == blank_id) {
          prefix->log_prob_b_cur =
              log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
          continue;
        }
        // repeated character: staying on this prefix requires a
        // non-blank ending in the previous step
        if (c == prefix->character) {
          prefix->log_prob_nb_cur = log_sum_exp(
              prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
        }
        // get new prefix (nullptr when the dictionary FST rejects it)
        auto prefix_new = prefix->get_path_trie(c);

        if (prefix_new != nullptr) {
          float log_p = -NUM_FLT_INF;

          // extending with a repeated character requires an intervening blank
          if (c == prefix->character &&
              prefix->log_prob_b_prev > -NUM_FLT_INF) {
            log_p = log_prob_c + prefix->log_prob_b_prev;
          } else if (c != prefix->character) {
            log_p = log_prob_c + prefix->score;
          }

          // language model scoring: char-based scorers score every
          // extension, word-based scorers only at spaces
          if (ext_scorer != nullptr &&
              (c == space_id || ext_scorer->is_character_based())) {
            PathTrie *prefix_to_score = nullptr;
            // skip scoring the space itself for word-based scorers
            if (ext_scorer->is_character_based()) {
              prefix_to_score = prefix_new;
            } else {
              prefix_to_score = prefix;
            }

            float score = 0.0;
            std::vector<std::string> ngram;
            ngram = ext_scorer->make_ngram(prefix_to_score);
            score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
            log_p += score;
            log_p += ext_scorer->beta;
          }
          prefix_new->log_prob_nb_cur =
              log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
      }  // end of loop over prefix
    }    // end of loop over vocabulary

    // Rebuild the prefix list from the trie; this also rolls the *_cur
    // probabilities over to *_prev for the next time step.
    prefixes.clear();
    // update log probs
    root.iterate_to_vec(prefixes);

    // only preserve top beam_size prefixes
    if (prefixes.size() >= beam_size) {
      std::nth_element(prefixes.begin(),
                       prefixes.begin() + beam_size,
                       prefixes.end(),
                       prefix_compare);
      for (size_t i = beam_size; i < prefixes.size(); ++i) {
        prefixes[i]->remove();
      }
    }
  }  // end of loop over time

  // score the last word of each prefix that doesn't end with space
  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
      auto prefix = prefixes[i];
      if (!prefix->is_empty() && prefix->character != space_id) {
        float score = 0.0;
        std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
        score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
        score += ext_scorer->beta;
        prefix->score += score;
      }
    }
  }

  size_t num_prefixes = std::min(prefixes.size(), beam_size);
  std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

  // compute aproximate ctc score as the return score, without affecting the
  // return order of decoding result. To delete when decoder gets stable.
  for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
    double approx_ctc = prefixes[i]->score;
    if (ext_scorer != nullptr) {
      std::vector<int> output;
      prefixes[i]->get_path_vec(output);
      auto prefix_length = output.size();
      auto words = ext_scorer->split_labels(output);
      // remove word insert
      approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
      // remove language model weight:
      approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
    }
    prefixes[i]->approx_ctc = approx_ctc;
  }

  return get_beam_search_result(prefixes, vocabulary, beam_size);
}
|
||||
|
||||
|
||||
std::vector<std::vector<std::pair<double, std::string>>>
|
||||
ctc_beam_search_decoder_batch(
|
||||
const std::vector<std::vector<std::vector<double>>> &probs_split,
|
||||
const std::vector<std::string> &vocabulary,
|
||||
size_t beam_size,
|
||||
size_t num_processes,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n,
|
||||
Scorer *ext_scorer) {
|
||||
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
|
||||
// thread pool
|
||||
ThreadPool pool(num_processes);
|
||||
// number of samples
|
||||
size_t batch_size = probs_split.size();
|
||||
|
||||
// enqueue the tasks of decoding
|
||||
std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
|
||||
probs_split[i],
|
||||
vocabulary,
|
||||
beam_size,
|
||||
cutoff_prob,
|
||||
cutoff_top_n,
|
||||
ext_scorer));
|
||||
}
|
||||
|
||||
// get decoding results
|
||||
std::vector<std::vector<std::pair<double, std::string>>> batch_results;
|
||||
for (size_t i = 0; i < batch_size; ++i) {
|
||||
batch_results.emplace_back(res[i].get());
|
||||
}
|
||||
return batch_results;
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_

#include <string>
#include <utility>
#include <vector>

#include "scorer.h"

/* CTC Beam Search Decoder

 * Parameters:
 *     probs_seq: 2-D vector that each element is a vector of probabilities
 *                over vocabulary of one time step.
 *     vocabulary: A vector of vocabulary.
 *     beam_size: The width of beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
 * Return:
 *     A vector that each element is a pair of score and decoding result,
 *     in descending order.
 */
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob = 1.0,
    size_t cutoff_top_n = 40,
    Scorer *ext_scorer = nullptr);

/* CTC Beam Search Decoder for batch data

 * Parameters:
 *     probs_split: 3-D vector that each element is a 2-D vector that can be
 *                  used by ctc_beam_search_decoder().
 *     vocabulary: A vector of vocabulary.
 *     beam_size: The width of beam search.
 *     num_processes: Number of threads for beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
 * Return:
 *     A 2-D vector that each element is a vector of beam search decoding
 *     result for one audio sample.
 */
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
    const std::vector<std::vector<std::vector<double>>> &probs_split,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    size_t num_processes,
    double cutoff_prob = 1.0,
    size_t cutoff_top_n = 40,
    Scorer *ext_scorer = nullptr);

#endif  // CTC_BEAM_SEARCH_DECODER_H_
|
@ -0,0 +1,45 @@
|
||||
#include "ctc_greedy_decoder.h"
|
||||
#include "decoder_utils.h"
|
||||
|
||||
std::string ctc_greedy_decoder(
|
||||
const std::vector<std::vector<double>> &probs_seq,
|
||||
const std::vector<std::string> &vocabulary) {
|
||||
// dimension check
|
||||
size_t num_time_steps = probs_seq.size();
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
VALID_CHECK_EQ(probs_seq[i].size(),
|
||||
vocabulary.size() + 1,
|
||||
"The shape of probs_seq does not match with "
|
||||
"the shape of the vocabulary");
|
||||
}
|
||||
|
||||
size_t blank_id = vocabulary.size();
|
||||
|
||||
std::vector<size_t> max_idx_vec(num_time_steps, 0);
|
||||
std::vector<size_t> idx_vec;
|
||||
for (size_t i = 0; i < num_time_steps; ++i) {
|
||||
double max_prob = 0.0;
|
||||
size_t max_idx = 0;
|
||||
const std::vector<double> &probs_step = probs_seq[i];
|
||||
for (size_t j = 0; j < probs_step.size(); ++j) {
|
||||
if (max_prob < probs_step[j]) {
|
||||
max_idx = j;
|
||||
max_prob = probs_step[j];
|
||||
}
|
||||
}
|
||||
// id with maximum probability in current time step
|
||||
max_idx_vec[i] = max_idx;
|
||||
// deduplicate
|
||||
if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
|
||||
idx_vec.push_back(max_idx_vec[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::string best_path_result;
|
||||
for (size_t i = 0; i < idx_vec.size(); ++i) {
|
||||
if (idx_vec[i] != blank_id) {
|
||||
best_path_result += vocabulary[idx_vec[i]];
|
||||
}
|
||||
}
|
||||
return best_path_result;
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H

#include <string>
#include <vector>

/* CTC Greedy (Best Path) Decoder
 *
 * Takes the most probable label at each time step, then collapses
 * consecutive repeats and removes the blank label.
 *
 * Parameters:
 *     probs_seq: 2-D vector that each element is a vector of probabilities
 *                over vocabulary of one time step.
 *     vocabulary: A vector of vocabulary.
 * Return:
 *     The decoding result in string
 */
std::string ctc_greedy_decoder(
    const std::vector<std::vector<double>>& probs_seq,
    const std::vector<std::string>& vocabulary);

#endif  // CTC_GREEDY_DECODER_H
|
@ -0,0 +1,176 @@
|
||||
#include "decoder_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
|
||||
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
|
||||
const std::vector<double> &prob_step,
|
||||
double cutoff_prob,
|
||||
size_t cutoff_top_n) {
|
||||
std::vector<std::pair<int, double>> prob_idx;
|
||||
for (size_t i = 0; i < prob_step.size(); ++i) {
|
||||
prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
|
||||
}
|
||||
// pruning of vacobulary
|
||||
size_t cutoff_len = prob_step.size();
|
||||
if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
|
||||
std::sort(
|
||||
prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
|
||||
if (cutoff_prob < 1.0) {
|
||||
double cum_prob = 0.0;
|
||||
cutoff_len = 0;
|
||||
for (size_t i = 0; i < prob_idx.size(); ++i) {
|
||||
cum_prob += prob_idx[i].second;
|
||||
cutoff_len += 1;
|
||||
if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
|
||||
}
|
||||
}
|
||||
prob_idx = std::vector<std::pair<int, double>>(
|
||||
prob_idx.begin(), prob_idx.begin() + cutoff_len);
|
||||
}
|
||||
std::vector<std::pair<size_t, float>> log_prob_idx;
|
||||
for (size_t i = 0; i < cutoff_len; ++i) {
|
||||
log_prob_idx.push_back(std::pair<int, float>(
|
||||
prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
|
||||
}
|
||||
return log_prob_idx;
|
||||
}
|
||||
|
||||
|
||||
// Convert the surviving trie prefixes into (score, text) pairs, best first.
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size) {
  // allow for the post processing
  std::vector<PathTrie *> space_prefixes;
  // NOTE(review): space_prefixes is always empty right after its
  // declaration, so this condition always holds; it looks like a
  // placeholder for future post-processing — confirm intent.
  if (space_prefixes.empty()) {
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
      space_prefixes.push_back(prefixes[i]);
    }
  }

  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
  std::vector<std::pair<double, std::string>> output_vecs;
  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
    std::vector<int> output;
    space_prefixes[i]->get_path_vec(output);
    // convert index to string
    std::string output_str;
    for (size_t j = 0; j < output.size(); j++) {
      output_str += vocabulary[output[j]];
    }
    // approx_ctc is negated so that a larger returned score is better
    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
                                               output_str);
    output_vecs.emplace_back(output_pair);
  }

  return output_vecs;
}
|
||||
|
||||
// Number of UTF-8 code points in `str`: count every byte that is not a
// continuation byte (continuation bytes look like 10xxxxxx).
size_t get_utf8_str_len(const std::string &str) {
  return std::count_if(str.begin(), str.end(), [](char byte) {
    return (byte & 0xc0) != 0x80;
  });
}
|
||||
|
||||
std::vector<std::string> split_utf8_str(const std::string &str) {
|
||||
std::vector<std::string> result;
|
||||
std::string out_str;
|
||||
|
||||
for (char c : str) {
|
||||
if ((c & 0xc0) != 0x80) // new UTF-8 character
|
||||
{
|
||||
if (!out_str.empty()) {
|
||||
result.push_back(out_str);
|
||||
out_str.clear();
|
||||
}
|
||||
}
|
||||
|
||||
out_str.append(1, c);
|
||||
}
|
||||
result.push_back(out_str);
|
||||
return result;
|
||||
}
|
||||
|
||||
// Split `s` on `delim`, dropping empty fields; delimiters at either end
// are trimmed, e.g. splitting "FooBarFoo" on "Foo" yields {"Bar"}.
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  std::size_t start = 0, delim_len = delim.size();
  // Guard: an empty delimiter matches at `start` forever and would never
  // advance it, looping infinitely; treat the whole string as one field.
  if (delim_len == 0) {
    if (!s.empty()) {
      result.push_back(s);
    }
    return result;
  }
  while (true) {
    std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      if (start < s.size()) {
        result.push_back(s.substr(start));
      }
      break;
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}
|
||||
|
||||
bool prefix_compare(const PathTrie *x, const PathTrie *y) {
|
||||
if (x->score == y->score) {
|
||||
if (x->character == y->character) {
|
||||
return false;
|
||||
} else {
|
||||
return (x->character < y->character);
|
||||
}
|
||||
} else {
|
||||
return x->score > y->score;
|
||||
}
|
||||
}
|
||||
|
||||
// Append one word (as a sequence of label ids) to the dictionary FST as a
// linear chain of arcs branching from the start state.
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary) {
  // Lazily create the start state on the first insertion.
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
    assert(start == 0);
    dictionary->SetStart(start);
  }
  fst::StdVectorFst::StateId src = dictionary->Start();
  fst::StdVectorFst::StateId dst;
  // NOTE(review): assumes `word` is non-empty — otherwise `dst` is used
  // uninitialized in SetFinal below; confirm callers never pass an empty
  // word.
  for (auto c : word) {
    dst = dictionary->AddState();
    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
    src = dst;
  }
  // Mark the last state final so the whole word is accepted.
  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}
|
||||
|
||||
bool add_word_to_dictionary(
|
||||
const std::string &word,
|
||||
const std::unordered_map<std::string, int> &char_map,
|
||||
bool add_space,
|
||||
int SPACE_ID,
|
||||
fst::StdVectorFst *dictionary) {
|
||||
auto characters = split_utf8_str(word);
|
||||
|
||||
std::vector<int> int_word;
|
||||
|
||||
for (auto &c : characters) {
|
||||
if (c == " ") {
|
||||
int_word.push_back(SPACE_ID);
|
||||
} else {
|
||||
auto int_c = char_map.find(c);
|
||||
if (int_c != char_map.end()) {
|
||||
int_word.push_back(int_c->second);
|
||||
} else {
|
||||
return false; // return without adding
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (add_space) {
|
||||
int_word.push_back(SPACE_ID);
|
||||
}
|
||||
|
||||
add_word_to_fst(int_word, dictionary);
|
||||
return true; // return with successful adding
|
||||
}
|
@ -0,0 +1,94 @@
|
||||
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_

#include <utility>
#include "fst/log.h"
#include "path_trie.h"

// Sentinel float values used across the decoders as "infinity" and as a
// tiny additive constant to keep log() finite.
const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();

// inline function for validation check; aborts via LOG(FATAL) on failure
inline void check(
    bool x, const char *expr, const char *file, int line, const char *err) {
  if (!x) {
    std::cout << "[" << file << ":" << line << "] ";
    LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
  }
}

#define VALID_CHECK(x, info) \
  check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)


// Function template for comparing two pairs by first element, descending
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
  return a.first > b.first;
}

// Function template for comparing two pairs by second element, descending
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
  return a.second > b.second;
}

// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
  static T num_min = -std::numeric_limits<T>::max();
  if (x <= num_min) return y;
  if (y <= num_min) return x;
  // subtract the max first for numerical stability
  T xmax = std::max(x, y);
  return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}

// Get pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n);

// Get beam search result from prefixes in trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size);

// Comparison function for prefix ordering (descending score)
bool prefix_compare(const PathTrie *x, const PathTrie *y);

/* Get length of utf8 encoding string
 * See: http://stackoverflow.com/a/4063229
 */
size_t get_utf8_str_len(const std::string &str);

/* Split a string into a list of strings on a given string
 * delimiter. NB: delimiters on beginning / end of string are
 * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
 */
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim);

/* Splits string into vector of strings representing
 * UTF-8 characters (not same as chars)
 */
std::vector<std::string> split_utf8_str(const std::string &str);

// Add a word in index form to the dictionary FST
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary);

// Add a word in string form to the dictionary FST; returns false if any
// character is missing from char_map
bool add_word_to_dictionary(
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
    bool add_space,
    int SPACE_ID,
    fst::StdVectorFst *dictionary);
#endif  // DECODER_UTILS_H_
|
@ -0,0 +1,33 @@
|
||||
// SWIG interface for the CTC decoder package: exposes the scorer and the
// greedy/beam-search decoders to Python, plus the std:: container
// instantiations their signatures require.
%module swig_decoders
%{
#include "scorer.h"
#include "ctc_greedy_decoder.h"
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h"
%}

%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%import "decoder_utils.h"

// Instantiate the concrete container types crossing the Python boundary.
namespace std {
%template(DoubleVector) std::vector<double>;
%template(IntVector) std::vector<int>;
%template(StringVector) std::vector<std::string>;
%template(VectorOfStructVector) std::vector<std::vector<double> >;
%template(FloatVector) std::vector<float>;
%template(Pair) std::pair<float, std::string>;
%template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
%template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
%template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
%template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
}

// Instantiate the pair comparators used when sorting decoder output.
%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;

%include "scorer.h"
%include "ctc_greedy_decoder.h"
%include "ctc_beam_search_decoder.h"
|
@ -0,0 +1,148 @@
|
||||
#include "path_trie.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
// Construct the root node: no character, no parent, no dictionary, and
// every log probability initialized to log(0).
PathTrie::PathTrie() {
  // ROOT_ is the sentinel character value marking the root.
  ROOT_ = -1;
  character = ROOT_;
  parent = nullptr;
  exists_ = true;

  // All log probabilities start at -inf, i.e. probability zero.
  log_prob_b_prev = -NUM_FLT_INF;
  log_prob_nb_prev = -NUM_FLT_INF;
  log_prob_b_cur = -NUM_FLT_INF;
  log_prob_nb_cur = -NUM_FLT_INF;
  score = -NUM_FLT_INF;

  // No dictionary FST attached until set_dictionary() is called.
  dictionary_ = nullptr;
  dictionary_state_ = 0;
  has_dictionary_ = false;
  matcher_ = nullptr;
}
|
||||
|
||||
// Recursively free the entire subtree hanging off this node.
PathTrie::~PathTrie() {
  for (auto &entry : children_) {
    delete entry.second;
  }
}
|
||||
|
||||
// Return the child reached by appending `new_char`, creating it on demand.
// When a dictionary FST is attached, returns nullptr if appending
// `new_char` cannot continue any dictionary word; in that case `reset`
// controls whether a completed word rewinds the FST walk to its start.
PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
  // Linear search for an existing edge labelled `new_char`.
  auto child = children_.begin();
  for (child = children_.begin(); child != children_.end(); ++child) {
    if (child->first == new_char) {
      break;
    }
  }
  if (child != children_.end()) {
    if (!child->second->exists_) {
      // Revive a logically-removed node: mark it live again and clear all
      // of its log probabilities back to log(0).
      child->second->exists_ = true;
      child->second->log_prob_b_prev = -NUM_FLT_INF;
      child->second->log_prob_nb_prev = -NUM_FLT_INF;
      child->second->log_prob_b_cur = -NUM_FLT_INF;
      child->second->log_prob_nb_cur = -NUM_FLT_INF;
    }
    return (child->second);
  } else {
    if (has_dictionary_) {
      matcher_->SetState(dictionary_state_);
      // FST arc labels are offset by +1 (label 0 would collide with the
      // FST start state -- see Scorer::set_char_map).
      bool found = matcher_->Find(new_char + 1);
      if (!found) {
        // Adding this character causes word outside dictionary
        auto FSTZERO = fst::TropicalWeight::Zero();
        auto final_weight = dictionary_->Final(dictionary_state_);
        bool is_final = (final_weight != FSTZERO);
        if (is_final && reset) {
          // A complete word just ended; restart the walk for the next word.
          dictionary_state_ = dictionary_->Start();
        }
        return nullptr;
      } else {
        PathTrie* new_path = new PathTrie;
        new_path->character = new_char;
        new_path->parent = this;
        // The child continues the dictionary walk from the matched arc.
        new_path->dictionary_ = dictionary_;
        new_path->dictionary_state_ = matcher_->Value().nextstate;
        new_path->has_dictionary_ = true;
        new_path->matcher_ = matcher_;
        children_.push_back(std::make_pair(new_char, new_path));
        return new_path;
      }
    } else {
      // No dictionary: unconstrained trie growth.
      PathTrie* new_path = new PathTrie;
      new_path->character = new_char;
      new_path->parent = this;
      children_.push_back(std::make_pair(new_char, new_path));
      return new_path;
    }
  }
}
|
||||
|
||||
// Collect the full prefix (character indices) from the root to this node.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
  // Using ROOT_ as the stop marker walks all the way up to the root.
  return get_path_vec(output, ROOT_);
}
|
||||
|
||||
// Walk upward from this node, appending characters to `output`, until a
// node labelled `stop` (or the root) is reached or `max_steps` characters
// have been collected.  `output` is reversed into root-to-leaf order
// before returning; the node the walk stopped at is returned.
PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
  if (character == stop || character == ROOT_ || output.size() == max_steps) {
    // Characters were collected leaf-first; flip to natural order.
    std::reverse(output.begin(), output.end());
    return this;
  } else {
    output.push_back(character);
    return parent->get_path_vec(output, stop, max_steps);
  }
}
|
||||
|
||||
// For every live node in this subtree: roll the current-step log
// probabilities into the *_prev slots, clear the current accumulators,
// recompute the node's score, and append the node to `output`.
void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
  if (exists_) {
    // The just-finished time step becomes the "previous" step...
    log_prob_b_prev = log_prob_b_cur;
    log_prob_nb_prev = log_prob_nb_cur;

    // ...and the current accumulators are reset to log(0).
    log_prob_b_cur = -NUM_FLT_INF;
    log_prob_nb_cur = -NUM_FLT_INF;

    // Total prefix score combines blank- and non-blank-ending paths.
    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
    output.push_back(this);
  }
  for (auto &entry : children_) {
    entry.second->iterate_to_vec(output);
  }
}
|
||||
|
||||
// Mark this node as removed; if it has no children, physically delete it,
// unlink it from its parent, and recursively prune ancestors that are
// themselves dead and now childless.
// NOTE(review): assumes `parent` is non-null when children_ is empty --
// calling remove() on the root would dereference a null parent; confirm
// callers never do this.
void PathTrie::remove() {
  exists_ = false;

  if (children_.size() == 0) {
    // Unlink this node from its parent's child list.
    auto child = parent->children_.begin();
    for (child = parent->children_.begin(); child != parent->children_.end();
         ++child) {
      if (child->first == character) {
        parent->children_.erase(child);
        break;
      }
    }

    // Prune the parent too if it is dead and has no remaining children.
    if (parent->children_.size() == 0 && !parent->exists_) {
      parent->remove();
    }

    // Safe: nothing touches `this` after the delete.
    delete this;
  }
}
|
||||
|
||||
// Attach a prebuilt dictionary FST and position the walk at its start state.
void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
  dictionary_ = dictionary;
  dictionary_state_ = dictionary->Start();
  has_dictionary_ = true;
}
|
||||
|
||||
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
// Share the sorted matcher used to look up dictionary arcs; one matcher is
// shared across the whole trie (see get_path_trie).
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
  matcher_ = matcher;
}
|
@ -0,0 +1,67 @@
|
||||
#ifndef PATH_TRIE_H
#define PATH_TRIE_H

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "fst/fstlib.h"

/* Trie tree for prefix storing and manipulating, with a dictionary in
 * finite-state transducer for spelling correction.
 */
class PathTrie {
 public:
  PathTrie();
  ~PathTrie();

  // get new prefix after appending new char; nullptr if a dictionary is
  // attached and the extended prefix leaves every dictionary word
  PathTrie* get_path_trie(int new_char, bool reset = true);

  // get the prefix in index from root to current node
  PathTrie* get_path_vec(std::vector<int>& output);

  // get the prefix in index from some stop node to current node
  PathTrie* get_path_vec(std::vector<int>& output,
                         int stop,
                         size_t max_steps = std::numeric_limits<size_t>::max());

  // update log probs: roll current-step values into *_prev and collect
  // live nodes of the subtree into output
  void iterate_to_vec(std::vector<PathTrie*>& output);

  // set dictionary for FST
  void set_dictionary(fst::StdVectorFst* dictionary);

  // share the sorted matcher used to walk dictionary arcs
  void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);

  // true only for the root node (its character is the ROOT_ sentinel)
  bool is_empty() { return ROOT_ == character; }

  // remove current path from root
  void remove();

  // log probabilities of this prefix ending in blank (b) / non-blank (nb),
  // for the previous and current time steps
  float log_prob_b_prev;
  float log_prob_nb_prev;
  float log_prob_b_cur;
  float log_prob_nb_cur;
  // combined prefix score: log_sum_exp of the two *_prev values
  float score;
  float approx_ctc;
  // vocabulary index of the last character of this prefix
  int character;
  PathTrie* parent;

 private:
  // sentinel character value (-1) marking the root node
  int ROOT_;
  // false once remove() has marked this node dead
  bool exists_;
  bool has_dictionary_;

  std::vector<std::pair<int, PathTrie*>> children_;

  // pointer to dictionary of FST
  fst::StdVectorFst* dictionary_;
  // current FST state of the dictionary walk for this prefix
  fst::StdVectorFst::StateId dictionary_state_;
  // matcher used to find arcs in the FST
  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};

#endif  // PATH_TRIE_H
|
@ -0,0 +1,230 @@
|
||||
#include "scorer.h"
|
||||
|
||||
#include <unistd.h>
|
||||
#include <iostream>
|
||||
|
||||
#include "lm/config.hh"
|
||||
#include "lm/model.hh"
|
||||
#include "lm/state.hh"
|
||||
#include "util/string_piece.hh"
|
||||
#include "util/tokenize_piece.hh"
|
||||
|
||||
#include "decoder_utils.h"
|
||||
|
||||
using namespace lm::ngram;
|
||||
|
||||
// Build a scorer with LM weight `alpha` and word-insertion weight `beta`,
// loading the KenLM model at `lm_path` and registering `vocab_list` as the
// decoder vocabulary.
Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::vector<std::string>& vocab_list) {
  this->alpha = alpha;
  this->beta = beta;

  dictionary = nullptr;
  // Assumed character based until load_lm() finds a multi-character token.
  is_character_based_ = true;
  language_model_ = nullptr;

  max_order_ = 0;
  dict_size_ = 0;
  SPACE_ID_ = -1;

  setup(lm_path, vocab_list);
}
|
||||
|
||||
// Release the owned LM and dictionary.  Both are stored as void* to keep
// KenLM/OpenFst types out of the public header, so cast back to delete.
Scorer::~Scorer() {
  if (language_model_ != nullptr) {
    delete static_cast<lm::base::Model*>(language_model_);
  }
  if (dictionary != nullptr) {
    delete static_cast<fst::StdVectorFst*>(dictionary);
  }
}
|
||||
|
||||
void Scorer::setup(const std::string& lm_path,
|
||||
const std::vector<std::string>& vocab_list) {
|
||||
// load language model
|
||||
load_lm(lm_path);
|
||||
// set char map for scorer
|
||||
set_char_map(vocab_list);
|
||||
// fill the dictionary for FST
|
||||
if (!is_character_based()) {
|
||||
fill_dictionary(true);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the KenLM model at `lm_path`, record its n-gram order and
// vocabulary, and detect whether the model is character based (i.e. no
// ordinary token is longer than one UTF-8 character).
void Scorer::load_lm(const std::string& lm_path) {
  const char* filename = lm_path.c_str();
  // Fail fast if the model file does not exist / is unreadable.
  VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");

  // Callback that collects every vocabulary string while the LM loads.
  RetriveStrEnumerateVocab enumerate;
  lm::ngram::Config config;
  config.enumerate_vocab = &enumerate;
  language_model_ = lm::ngram::LoadVirtual(filename, config);
  max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
  vocabulary_ = enumerate.vocabulary;
  for (size_t i = 0; i < vocabulary_.size(); ++i) {
    // Any multi-character token other than <s>/</s>/<unk> implies a
    // word-based model.
    if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
        vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
        get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
      is_character_based_ = false;
    }
  }
}
|
||||
|
||||
double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
|
||||
lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
|
||||
double cond_prob;
|
||||
lm::ngram::State state, tmp_state, out_state;
|
||||
// avoid to inserting <s> in begin
|
||||
model->NullContextWrite(&state);
|
||||
for (size_t i = 0; i < words.size(); ++i) {
|
||||
lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
|
||||
// encounter OOV
|
||||
if (word_index == 0) {
|
||||
return OOV_SCORE;
|
||||
}
|
||||
cond_prob = model->BaseScore(&state, word_index, &out_state);
|
||||
tmp_state = state;
|
||||
state = out_state;
|
||||
out_state = tmp_state;
|
||||
}
|
||||
// return log10 prob
|
||||
return cond_prob;
|
||||
}
|
||||
|
||||
// Score a whole sentence: pad with <s> tokens so the first real word has a
// full (max_order_ - 1)-gram of context, append </s>, and sum the n-gram
// conditional log probabilities.
// NOTE(review): assumes max_order_ >= 1 (set by load_lm); max_order_ == 0
// would underflow the size_t bound `max_order_ - 1` below -- confirm.
double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
  std::vector<std::string> sentence;
  if (words.size() == 0) {
    // Empty input: score a sentence made only of start tokens plus </s>.
    for (size_t i = 0; i < max_order_; ++i) {
      sentence.push_back(START_TOKEN);
    }
  } else {
    for (size_t i = 0; i < max_order_ - 1; ++i) {
      sentence.push_back(START_TOKEN);
    }
    sentence.insert(sentence.end(), words.begin(), words.end());
  }
  sentence.push_back(END_TOKEN);
  return get_log_prob(sentence);
}
|
||||
|
||||
// Sum the conditional log probability over every max_order_-sized sliding
// window of `words`.  Requires words.size() > max_order_ -- callers pad
// with <s>/<\/s> (see get_sent_log_prob).
double Scorer::get_log_prob(const std::vector<std::string>& words) {
  assert(words.size() > max_order_);
  double score = 0.0;
  for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
    // One n-gram window starting at position i.
    std::vector<std::string> ngram(words.begin() + i,
                                   words.begin() + i + max_order_);
    score += get_log_cond_prob(ngram);
  }
  return score;
}
|
||||
|
||||
// Re-weight the scorer: alpha scales the LM score, beta the word count.
void Scorer::reset_params(float alpha, float beta) {
  this->alpha = alpha;
  this->beta = beta;
}
|
||||
|
||||
// Map a sequence of vocabulary indices back to the concatenated string.
std::string Scorer::vec2str(const std::vector<int>& input) {
  std::string word;
  for (auto index : input) {
    word += char_list_[index];
  }
  return word;
}
|
||||
|
||||
// Convert label indices into the token list the LM scores: one UTF-8
// character per token for character-based models, otherwise one word per
// space-separated token.
std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
  if (labels.empty()) return {};

  std::string sent = vec2str(labels);
  return is_character_based_ ? split_utf8_str(sent) : split_str(sent, " ");
}
|
||||
|
||||
// Record the decoder vocabulary and build the string -> FST-label map,
// also remembering which index is the space character.
void Scorer::set_char_map(const std::vector<std::string>& char_list) {
  char_list_ = char_list;
  char_map_.clear();

  // Set the char map for the FST for spelling correction
  for (size_t i = 0; i < char_list_.size(); i++) {
    if (char_list_[i] == " ") {
      SPACE_ID_ = i;
    }
    // The initial state of FST is state 0, hence the index of chars in
    // the FST should start from 1 to avoid the conflict with the initial
    // state, otherwise wrong decoding results would be given.
    char_map_[char_list_[i]] = i + 1;
  }
}
|
||||
|
||||
// Reconstruct the last max_order_ LM tokens ending at `prefix` by walking
// the trie upward: one UTF-8 character per token for character-based
// models, one space-delimited word per token otherwise.  Pads with <s>
// when the prefix is shorter than the model order; result is in natural
// (oldest-first) order.
std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
  std::vector<std::string> ngram;
  PathTrie* current_node = prefix;
  PathTrie* new_node = nullptr;

  for (int order = 0; order < max_order_; order++) {
    std::vector<int> prefix_vec;

    if (is_character_based_) {
      // Take exactly one character per token (max_steps = 1).
      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
      current_node = new_node;
    } else {
      // Take everything back to the previous space as one word.
      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
      current_node = new_node->parent;  // Skipping spaces
    }

    // reconstruct word
    std::string word = vec2str(prefix_vec);
    ngram.push_back(word);

    // character == -1 is the root sentinel: the whole prefix is consumed.
    if (new_node->character == -1) {
      // No more spaces, but still need order
      for (int i = 0; i < max_order_ - order - 1; i++) {
        ngram.push_back(START_TOKEN);
      }
      break;
    }
  }
  // Tokens were collected newest-first; flip to natural order.
  std::reverse(ngram.begin(), ngram.end());
  return ngram;
}
|
||||
|
||||
// Build the spelling-correction dictionary FST from the LM vocabulary:
// insert every unigram, then epsilon-remove, determinize, and minimize the
// automaton before storing it in `this->dictionary`.
void Scorer::fill_dictionary(bool add_space) {
  fst::StdVectorFst dictionary;
  // For each unigram convert to ints and put in trie
  int dict_size = 0;
  for (const auto& word : vocabulary_) {
    // SPACE_ID_ + 1 matches the +1 label offset used in set_char_map.
    bool added = add_word_to_dictionary(
        word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
    dict_size += added ? 1 : 0;
  }

  dict_size_ = dict_size;

  /* Simplify FST

   * This gets rid of "epsilon" transitions in the FST.
   * These are transitions that don't require a string input to be taken.
   * Getting rid of them is necessary to make the FST deterministic, but
   * can greatly increase the size of the FST
   */
  fst::RmEpsilon(&dictionary);
  fst::StdVectorFst* new_dict = new fst::StdVectorFst;

  /* This makes the FST deterministic, meaning for any string input there's
   * only one possible state the FST could be in. It is assumed our
   * dictionary is deterministic when using it.
   * (lest we'd have to check for multiple transitions at each state)
   */
  fst::Determinize(dictionary, new_dict);

  /* Finds the simplest equivalent fst. This is unnecessary but decreases
   * memory usage of the dictionary
   */
  fst::Minimize(new_dict);
  // Ownership transfers to the Scorer; freed in ~Scorer().
  this->dictionary = new_dict;
}
|
@ -0,0 +1,112 @@
|
||||
#ifndef SCORER_H_
#define SCORER_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"

#include "path_trie.h"

// Log10 score assigned to out-of-vocabulary words.
const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";

// Implement a callback to retrieve the vocabulary of the language model
// while it loads (KenLM calls Add once per vocabulary entry).
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 public:
  RetriveStrEnumerateVocab() {}

  void Add(lm::WordIndex index, const StringPiece &str) {
    vocabulary.push_back(std::string(str.data(), str.length()));
  }

  // every vocabulary string seen during loading
  std::vector<std::string> vocabulary;
};

/* External scorer to query score for n-gram or sentence, including language
 * model scoring and word insertion.
 *
 * Example:
 *     Scorer scorer(alpha, beta, "path_of_language_model");
 *     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
 *     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
 */
class Scorer {
 public:
  Scorer(double alpha,
         double beta,
         const std::string &lm_path,
         const std::vector<std::string> &vocabulary);
  ~Scorer();

  // log10 conditional probability of the last word given the others
  double get_log_cond_prob(const std::vector<std::string> &words);

  // log10 probability of the full sentence (padded with <s>/</s>)
  double get_sent_log_prob(const std::vector<std::string> &words);

  // return the max order
  size_t get_max_order() const { return max_order_; }

  // return the dictionary size of language model
  size_t get_dict_size() const { return dict_size_; }

  // return true if the language model is character based
  bool is_character_based() const { return is_character_based_; }

  // reset params alpha & beta
  void reset_params(float alpha, float beta);

  // make ngram for a given prefix
  std::vector<std::string> make_ngram(PathTrie *prefix);

  // transform the labels in index to the vector of words (word based lm) or
  // the vector of characters (character based lm)
  std::vector<std::string> split_labels(const std::vector<int> &labels);

  // language model weight
  double alpha;
  // word insertion weight
  double beta;

  // pointer to the dictionary of FST (fst::StdVectorFst*, type-erased to
  // keep OpenFst out of this public header)
  void *dictionary;

 protected:
  // necessary setup: load language model, set char map, fill FST's dictionary
  void setup(const std::string &lm_path,
             const std::vector<std::string> &vocab_list);

  // load language model from given path
  void load_lm(const std::string &lm_path);

  // fill dictionary for FST
  void fill_dictionary(bool add_space);

  // set char map
  void set_char_map(const std::vector<std::string> &char_list);

  double get_log_prob(const std::vector<std::string> &words);

  // translate the vector in index to string
  std::string vec2str(const std::vector<int> &input);

 private:
  // lm::base::Model*, type-erased for the same reason as `dictionary`
  void *language_model_;
  bool is_character_based_;
  size_t max_order_;
  size_t dict_size_;

  // index of the space character in char_list_, -1 if absent
  int SPACE_ID_;
  std::vector<std::string> char_list_;
  // char string -> FST label (index + 1; 0 is reserved, see set_char_map)
  std::unordered_map<std::string, int> char_map_;

  std::vector<std::string> vocabulary_;
};

#endif  // SCORER_H_
|
@ -0,0 +1,119 @@
|
||||
"""Script to build and install decoder package."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from setuptools import setup, Extension, distutils
|
||||
import glob
|
||||
import platform
|
||||
import os, sys
|
||||
import multiprocessing.pool
|
||||
import argparse
|
||||
|
||||
# Options consumed by this build script itself; anything unrecognized is
# forwarded to setuptools below.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--num_processes",
    default=1,
    type=int,
    help="Number of cpu processes to build package. (default: %(default)d)")
args = parser.parse_known_args()

# reconstruct sys.argv to pass to setup below
sys.argv = [sys.argv[0]] + args[1]
|
||||
|
||||
|
||||
# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
                     sources,
                     output_dir=None,
                     macros=None,
                     include_dirs=None,
                     debug=0,
                     extra_preargs=None,
                     extra_postargs=None,
                     depends=None):
    """Drop-in replacement for distutils CCompiler.compile that compiles
    the sources on a thread pool (--num_processes wide) instead of serially.

    Signature must match distutils.ccompiler.CCompiler.compile exactly,
    since it is assigned over that method below.
    """
    # those lines are copied from distutils.ccompiler.CCompiler directly
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    # parallel code
    def _single_compile(obj):
        # Objects already up to date are absent from `build`; skip them.
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    # convert to list, imap is evaluated on-demand
    thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
    list(thread_pool.imap(_single_compile, objects))
    return objects
||||
|
||||
|
||||
def compile_test(header, library):
    """Return True if a trivial program that includes `header` and links
    against `library` compiles, i.e. the system library is available.
    """
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = ("bash -c \"g++ -include %s -l%s -x c++ - <<<'int main() {}' "
               "-o %s >/dev/null 2>/dev/null && rm %s 2>/dev/null\""
               % (header, library, dummy_path, dummy_path))
    return os.system(command) == 0
|
||||
|
||||
|
||||
# hack compile to support parallel compiling
distutils.ccompiler.CCompiler.compile = parallelCCompile

# KenLM sources (the language-model backend used by the scorer).
FILES = glob.glob('kenlm/util/*.cc') \
        + glob.glob('kenlm/lm/*.cc') \
        + glob.glob('kenlm/util/double-conversion/*.cc')

# OpenFst sources (dictionary FST used for spelling correction).
FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')

# Drop executables and test files; only library sources go into the module.
FILES = [
    fn for fn in FILES
    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
        'unittest.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    # librt is Linux-only; macOS provides these symbols in libc.
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']

# Enable optional compressed-model support only when the corresponding
# system library is actually present.
if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

# Generate the SWIG wrapper sources (picked up by the *.cxx glob below).
os.system('swig -python -c++ ./decoders.i')

decoders_module = [
    Extension(
        name='_swig_decoders',
        sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
        language='c++',
        include_dirs=[
            '.',
            'kenlm',
            'openfst-1.6.3/src/include',
            'ThreadPool',
        ],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_decoders',
    version='1.1',
    description="""CTC decoders""",
    ext_modules=decoders_module,
    py_modules=['swig_decoders'], )
|
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env bash
# Fetch third-party dependencies (kenlm, openfst, ThreadPool) and build the
# decoder package.  Abort on the first failed step so a broken download does
# not silently proceed to the install (consistent with the repo's other
# shell scripts, which use `set -e`).
set -e

if [ ! -d kenlm ]; then
    git clone https://github.com/luotao1/kenlm.git
    echo -e "\n"
fi

if [ ! -d openfst-1.6.3 ]; then
    echo "Download and extract openfst ..."
    wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
    tar -xzvf openfst-1.6.3.tar.gz
    echo -e "\n"
fi

if [ ! -d ThreadPool ]; then
    git clone https://github.com/progschj/ThreadPool.git
    echo -e "\n"
fi

echo "Install decoders ..."
python setup.py install --num_processes 4
|
@ -0,0 +1,124 @@
|
||||
"""Wrapper for various CTC decoders in SWIG."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import swig_decoders
|
||||
|
||||
|
||||
class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        # Explicit base-class call (not super()): SWIG proxy classes are
        # constructed this way throughout this module.
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, vocabulary):
    """Wrapper for the CTC best-path (greedy) decoder in SWIG.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    best_path = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
    # SWIG returns a byte string; decode it to unicode for callers.
    return best_path.decode('utf-8')
|
||||
|
||||
|
||||
def ctc_beam_search_decoder(probs_seq,
                            vocabulary,
                            beam_size,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    raw_results = swig_decoders.ctc_beam_search_decoder(
        probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
        ext_scoring_func)
    # SWIG returns byte strings; decode each candidate sentence to unicode.
    return [(result[0], result[1].decode('utf-8')) for result in raw_results]
|
||||
|
||||
|
||||
def ctc_beam_search_decoder_batch(probs_split,
                                  vocabulary,
                                  beam_size,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
                         characters with highest probs in vocabulary will be
                         used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    probs_split = [probs_seq.tolist() for probs_seq in probs_split]

    raw_batch = swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func)
    # Decode every candidate sentence of every utterance to unicode.
    return [[(result[0], result[1].decode("utf-8")) for result in utt_results]
            for utt_results in raw_batch]
|
@ -0,0 +1,90 @@
|
||||
"""Test decoders."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
from decoders import decoders_deprecated as decoder
|
||||
|
||||
|
||||
class TestDecoders(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd']
|
||||
self.beam_size = 20
|
||||
self.probs_seq1 = [[
|
||||
0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
|
||||
0.18184413, 0.16493624
|
||||
], [
|
||||
0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
|
||||
0.0094893, 0.06890021
|
||||
], [
|
||||
0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
|
||||
0.08424043, 0.08120984
|
||||
], [
|
||||
0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
|
||||
0.05206269, 0.09772094
|
||||
], [
|
||||
0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
|
||||
0.41317442, 0.01946335
|
||||
], [
|
||||
0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
|
||||
0.04377724, 0.01457421
|
||||
]]
|
||||
self.probs_seq2 = [[
|
||||
0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
|
||||
0.04468023, 0.10903471
|
||||
], [
|
||||
0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
|
||||
0.10219457, 0.20640612
|
||||
], [
|
||||
0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
|
||||
0.12298585, 0.01654384
|
||||
], [
|
||||
0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
|
||||
0.22538587, 0.13483174
|
||||
], [
|
||||
0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
|
||||
0.07113197, 0.04139363
|
||||
], [
|
||||
0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
|
||||
0.05294827, 0.22298418
|
||||
]]
|
||||
self.greedy_result = ["ac'bdc", "b'da"]
|
||||
self.beam_search_result = ['acdc', "b'a"]
|
||||
|
||||
def test_greedy_decoder_1(self):
|
||||
bst_result = decoder.ctc_greedy_decoder(self.probs_seq1,
|
||||
self.vocab_list)
|
||||
self.assertEqual(bst_result, self.greedy_result[0])
|
||||
|
||||
def test_greedy_decoder_2(self):
|
||||
bst_result = decoder.ctc_greedy_decoder(self.probs_seq2,
|
||||
self.vocab_list)
|
||||
self.assertEqual(bst_result, self.greedy_result[1])
|
||||
|
||||
def test_beam_search_decoder_1(self):
|
||||
beam_result = decoder.ctc_beam_search_decoder(
|
||||
probs_seq=self.probs_seq1,
|
||||
beam_size=self.beam_size,
|
||||
vocabulary=self.vocab_list)
|
||||
self.assertEqual(beam_result[0][1], self.beam_search_result[0])
|
||||
|
||||
    def test_beam_search_decoder_2(self):
        """Beam search on probs_seq2: the top hypothesis matches the expected
        result."""
        beam_result = decoder.ctc_beam_search_decoder(
            probs_seq=self.probs_seq2,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list)
        self.assertEqual(beam_result[0][1], self.beam_search_result[1])
|
||||
|
||||
    def test_beam_search_decoder_batch(self):
        """Batched beam search reproduces the per-sequence single-call
        results for both sequences."""
        beam_results = decoder.ctc_beam_search_decoder_batch(
            probs_split=[self.probs_seq1, self.probs_seq2],
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            num_processes=24)
        self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
        self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,19 @@
|
||||
"""Set up paths for DS2"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
|
||||
def add_path(path):
    """Prepend *path* to ``sys.path`` unless it is already present."""
    already_known = path in sys.path
    if not already_known:
        sys.path.insert(0, path)
|
||||
|
||||
|
||||
# Directory containing this file (empty string if run from the same dir).
this_dir = os.path.dirname(__file__)

# Add project path to PYTHONPATH so sibling packages resolve when this
# module is imported from a subdirectory (e.g. deploy/).
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
|
@ -0,0 +1,94 @@
|
||||
"""Client-end for the ASR demo."""
|
||||
from pynput import keyboard
|
||||
import struct
|
||||
import socket
|
||||
import sys
|
||||
import argparse
|
||||
import pyaudio
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--host_ip",
|
||||
default="localhost",
|
||||
type=str,
|
||||
help="Server IP address. (default: %(default)s)")
|
||||
parser.add_argument(
|
||||
"--host_port",
|
||||
default=8086,
|
||||
type=int,
|
||||
help="Server Port. (default: %(default)s)")
|
||||
args = parser.parse_args()
|
||||
|
||||
is_recording = False
|
||||
enable_trigger_record = True
|
||||
|
||||
|
||||
def on_press(key):
    """On-press keyboard callback: space starts a recording when armed."""
    global is_recording, enable_trigger_record
    if key != keyboard.Key.space:
        return
    if is_recording or not enable_trigger_record:
        return
    sys.stdout.write("Start Recording ... ")
    sys.stdout.flush()
    is_recording = True
|
||||
|
||||
|
||||
def on_release(key):
    """On-release keyboard callback.

    Esc stops the pynput listener (returning False terminates its loop);
    releasing space ends an in-progress recording.
    """
    global is_recording, enable_trigger_record
    if key == keyboard.Key.esc:
        # Returning False tells keyboard.Listener to stop.
        return False
    elif key == keyboard.Key.space:
        # Fixed: test truthiness directly instead of `== True`.
        if is_recording:
            is_recording = False
|
||||
|
||||
|
||||
data_list = []
|
||||
|
||||
|
||||
def callback(in_data, frame_count, time_info, status):
    """Audio recorder's stream callback function.

    While recording, buffers incoming audio chunks; once recording stops,
    sends the buffered speech (length-prefixed) to the ASR server, prints
    the recognition result, and re-arms the record trigger.
    """
    global data_list, is_recording, enable_trigger_record
    if is_recording:
        data_list.append(in_data)
        # Block re-triggering until this utterance has been shipped.
        enable_trigger_record = False
    elif len(data_list) > 0:
        # Connect to server and send data.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((args.host_ip, args.host_port))
            sent = ''.join(data_list)
            sock.sendall(struct.pack('>i', len(sent)) + sent)
            print('Speech[length=%d] Sent.' % len(sent))
            # Receive data from the server and shut down.
            received = sock.recv(1024)
            # Fixed: use the print() function; the bare Python-2
            # `print "..."` statement was inconsistent with the
            # print(...) calls elsewhere in this file.
            print("Recognition Results: {}".format(received))
        finally:
            # Fixed: close the socket even if connect/send/recv raises.
            sock.close()
        data_list = []
        enable_trigger_record = True
    return (in_data, pyaudio.paContinue)
|
||||
|
||||
|
||||
def main():
    """Record audio from the default input device and stream completed
    utterances to the ASR server (see `callback`) until Esc is pressed.
    """
    # prepare audio recorder: mono, 16 kHz, 32-bit samples; PyAudio
    # invokes `callback` with each captured buffer.
    p = pyaudio.PyAudio()
    stream = p.open(
        format=pyaudio.paInt32,
        channels=1,
        rate=16000,
        input=True,
        stream_callback=callback)
    stream.start_stream()

    # prepare keyboard listener; join() blocks until `on_release`
    # returns False (the Esc key).
    with keyboard.Listener(
            on_press=on_press, on_release=on_release) as listener:
        listener.join()

    # close up
    stream.stop_stream()
    stream.close()
    p.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,215 @@
|
||||
"""Server-end for the ASR demo."""
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import argparse
|
||||
import functools
|
||||
from time import gmtime, strftime
|
||||
import SocketServer
|
||||
import struct
|
||||
import wave
|
||||
import paddle.v2 as paddle
|
||||
import _init_paths
|
||||
from data_utils.data import DataGenerator
|
||||
from model_utils.model import DeepSpeech2Model
|
||||
from data_utils.utility import read_manifest
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
# yapf: disable
|
||||
add_arg('host_port', int, 8086, "Server's IP port.")
|
||||
add_arg('beam_size', int, 500, "Beam search width.")
|
||||
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
|
||||
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
|
||||
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
|
||||
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
|
||||
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
|
||||
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
|
||||
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
|
||||
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
|
||||
add_arg('use_gpu', bool, True, "Use GPU or not.")
|
||||
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
|
||||
"bi-directional RNNs. Not for GRU.")
|
||||
add_arg('host_ip', str,
|
||||
'localhost',
|
||||
"Server's IP address.")
|
||||
add_arg('speech_save_dir', str,
|
||||
'demo_cache',
|
||||
"Directory to save demo audios.")
|
||||
add_arg('warmup_manifest', str,
|
||||
'data/librispeech/manifest.test-clean',
|
||||
"Filepath of manifest to warm up.")
|
||||
add_arg('mean_std_path', str,
|
||||
'data/librispeech/mean_std.npz',
|
||||
"Filepath of normalizer's mean & std.")
|
||||
add_arg('vocab_path', str,
|
||||
'data/librispeech/eng_vocab.txt',
|
||||
"Filepath of vocabulary.")
|
||||
add_arg('model_path', str,
|
||||
'./checkpoints/libri/params.latest.tar.gz',
|
||||
"If None, the training starts from scratch, "
|
||||
"otherwise, it resumes from the pre-trained model.")
|
||||
add_arg('lang_model_path', str,
|
||||
'lm/data/common_crawl_00.prune01111.trie.klm',
|
||||
"Filepath for language model.")
|
||||
add_arg('decoding_method', str,
|
||||
'ctc_beam_search',
|
||||
"Decoding method. Options: ctc_beam_search, ctc_greedy",
|
||||
choices = ['ctc_beam_search', 'ctc_greedy'])
|
||||
add_arg('specgram_type', str,
|
||||
'linear',
|
||||
"Audio feature type. Options: linear, mfcc.",
|
||||
choices=['linear', 'mfcc'])
|
||||
# yapf: enable
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server.

    A TCPServer that additionally carries the directory for saving
    received speech and the callable that turns a saved audio file into
    a transcript; both are consumed by the request handler.
    """

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        # Fixed: forward the caller's `bind_and_activate` instead of
        # hard-coding True (the parameter was silently ignored).
        SocketServer.TCPServer.__init__(
            self,
            server_address,
            RequestHandlerClass,
            bind_and_activate=bind_and_activate)
|
||||
|
||||
|
||||
class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler.

    Reads one length-prefixed speech payload from the socket, saves it
    as a WAV file, runs it through the server's ASR handler, and sends
    the transcript back to the client.
    """

    def handle(self):
        # receive data through TCP socket; the first 4 bytes carry the
        # big-endian payload length (see the client's struct.pack('>i')).
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript.encode('utf-8'))

    def _write_to_file(self, data):
        """Save raw audio `data` as a mono, 32-bit, 16 kHz WAV file and
        return its path."""
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # Fixed: renamed `file` (shadowed the builtin) and close the
        # handle in a finally so it is released even if a write fails.
        wav_writer = wave.open(out_filename, 'wb')
        try:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(4)  # 32-bit samples, matching paInt32
            wav_writer.setframerate(16000)
            wav_writer.writeframes(data)
        finally:
            wav_writer.close()
        return out_filename
|
||||
|
||||
|
||||
def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test.

    Runs `audio_process_handler` on `num_test_cases` utterances sampled
    from the manifest (fixed seed for reproducibility), printing each
    response time and transcript.
    """
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        # Fixed: the template and values were passed to print() as
        # separate arguments (logging style), so the raw "%d: %s" was
        # printed verbatim; interpolate with % instead.
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
|
||||
|
||||
|
||||
def start_server():
    """Start the ASR server.

    Builds the data generator and DeepSpeech2 model, defines the
    file-to-transcript inference handler, warms it up on a few sampled
    utterances, then serves TCP requests forever.
    """
    # prepare data generator (no augmentation; keep raw transcription text)
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1,
        keep_transcription_text=True)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    # beam search needs the external LM scorer; greedy decoding does not
    if args.decoding_method == "ctc_beam_search":
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        # Turn one saved audio file into a transcript string.
        feature = data_generator.process_utterance(filename, "")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=[feature],
            feeding_dict=data_generator.feeding)

        if args.decoding_method == "ctc_greedy":
            result_transcript = ds2_model.decode_batch_greedy(
                probs_split=probs_split,
                vocab_list=vocab_list)
        else:
            result_transcript = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=args.alpha,
                beam_beta=args.beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=1)
        return result_transcript[0]

    # warming up with utterances sampled from Librispeech
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server; serve_forever() blocks until interrupted
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()
|
||||
|
||||
|
||||
def main():
    """Entry point: print the parsed arguments, initialize PaddlePaddle
    (single trainer), and run the ASR server."""
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
After Width: | Height: | Size: 153 KiB |
After Width: | Height: | Size: 108 KiB |
@ -0,0 +1,42 @@
|
||||
#! /usr/bin/env bash
# Prepare the Aishell corpus: download + manifests, vocabulary, and
# feature-normalizer statistics. Each stage aborts the script on failure.

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python data/aishell/aishell.py \
--manifest_prefix='data/aishell/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/Aishell'

if [ $? -ne 0 ]; then
    echo "Prepare Aishell failed. Terminated."
    exit 1
fi


# build vocabulary
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/aishell/vocab.txt' \
--manifest_paths 'data/aishell/manifest.train' 'data/aishell/manifest.dev'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/aishell/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/aishell/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "Aishell data preparation done."
exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_ch.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=300 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=2.6 \
|
||||
--beta=5.0 \
|
||||
--cutoff_prob=0.99 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--infer_manifest='data/aishell/manifest.test' \
|
||||
--mean_std_path='data/aishell/mean_std.npz' \
|
||||
--vocab_path='data/aishell/vocab.txt' \
|
||||
--model_path='checkpoints/aishell/params.latest.tar.gz' \
|
||||
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='cer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_ch.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/aishell > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=300 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=2.6 \
|
||||
--beta=5.0 \
|
||||
--cutoff_prob=0.99 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--infer_manifest='data/aishell/manifest.test' \
|
||||
--mean_std_path='models/aishell/mean_std.npz' \
|
||||
--vocab_path='models/aishell/vocab.txt' \
|
||||
--model_path='models/aishell/params.tar.gz' \
|
||||
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='cer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,47 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_ch.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=300 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=2.6 \
|
||||
--beta=5.0 \
|
||||
--cutoff_prob=0.99 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--test_manifest='data/aishell/manifest.test' \
|
||||
--mean_std_path='data/aishell/mean_std.npz' \
|
||||
--vocab_path='data/aishell/vocab.txt' \
|
||||
--model_path='checkpoints/aishell/params.latest.tar.gz' \
|
||||
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='cer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,56 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_ch.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/aishell > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=300 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=2.6 \
|
||||
--beta=5.0 \
|
||||
--cutoff_prob=0.99 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--test_manifest='data/aishell/manifest.test' \
|
||||
--mean_std_path='models/aishell/mean_std.npz' \
|
||||
--vocab_path='models/aishell/vocab.txt' \
|
||||
--model_path='models/aishell/params.tar.gz' \
|
||||
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='cer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --init_model_path
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u train.py \
|
||||
--batch_size=64 \
|
||||
--trainer_count=8 \
|
||||
--num_passes=50 \
|
||||
--num_proc_data=16 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--num_iter_print=100 \
|
||||
--learning_rate=5e-4 \
|
||||
--max_duration=27.0 \
|
||||
--min_duration=0.0 \
|
||||
--test_off=False \
|
||||
--use_sortagrad=True \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--is_local=True \
|
||||
--share_rnn_weights=False \
|
||||
--train_manifest='data/aishell/manifest.train' \
|
||||
--dev_manifest='data/aishell/manifest.dev' \
|
||||
--mean_std_path='data/aishell/mean_std.npz' \
|
||||
--vocab_path='data/aishell/vocab.txt' \
|
||||
--output_model_dir='./checkpoints/aishell' \
|
||||
--augment_conf_path='conf/augmentation.config' \
|
||||
--specgram_type='linear' \
|
||||
--shuffle_method='batch_shuffle_clipped'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/baidu_en8k > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=5 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=1.4 \
|
||||
--beta=0.35 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--infer_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='models/baidu_en8k/mean_std.npz' \
|
||||
--vocab_path='models/baidu_en8k/vocab.txt' \
|
||||
--model_path='models/baidu_en8k/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/baidu_en8k > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=4 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=1.4 \
|
||||
--beta=0.35 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--test_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='models/baidu_en8k/mean_std.npz' \
|
||||
--vocab_path='models/baidu_en8k/vocab.txt' \
|
||||
--model_path='models/baidu_en8k/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
@ -0,0 +1,17 @@
|
||||
#! /usr/bin/env bash
# Launch the ASR demo client against a locally running demo server.

cd ../.. > /dev/null

# start demo client
# Fixed: removed the stray trailing backslash after the last flag,
# which continued the command onto the following (blank) line and
# would silently swallow the next statement if the blank were removed.
CUDA_VISIBLE_DEVICES=0 \
python -u deploy/demo_client.py \
--host_ip='localhost' \
--host_port=8086

if [ $? -ne 0 ]; then
    echo "Failed in starting demo client!"
    exit 1
fi


exit 0
|
@ -0,0 +1,54 @@
|
||||
#! /usr/bin/env bash
|
||||
# TODO: replace the model with a mandarin model
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/baidu_en8k > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# start demo server
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u deploy/demo_server.py \
|
||||
--host_ip='localhost' \
|
||||
--host_port=8086 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=1024 \
|
||||
--alpha=1.15 \
|
||||
--beta=0.15 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=True \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=False \
|
||||
--speech_save_dir='demo_cache' \
|
||||
--warmup_manifest='data/tiny/manifest.test-clean' \
|
||||
--mean_std_path='models/baidu_en8k/mean_std.npz' \
|
||||
--vocab_path='models/baidu_en8k/vocab.txt' \
|
||||
--model_path='models/baidu_en8k/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in starting demo server!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,45 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download data, generate manifests
|
||||
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
|
||||
--manifest_prefix='data/librispeech/manifest' \
|
||||
--target_dir='~/.cache/paddle/dataset/speech/libri' \
|
||||
--full_download='True'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Prepare LibriSpeech failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cat data/librispeech/manifest.train-* | shuf > data/librispeech/manifest.train
|
||||
|
||||
|
||||
# build vocabulary
|
||||
python tools/build_vocab.py \
|
||||
--count_threshold=0 \
|
||||
--vocab_path='data/librispeech/vocab.txt' \
|
||||
--manifest_paths='data/librispeech/manifest.train'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Build vocabulary failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# compute mean and stddev for normalizer
|
||||
python tools/compute_mean_std.py \
|
||||
--manifest_path='data/librispeech/manifest.train' \
|
||||
--num_samples=2000 \
|
||||
--specgram_type='linear' \
|
||||
--output_path='data/librispeech/mean_std.npz'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Compute mean and stddev failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "LibriSpeech Data preparation done."
|
||||
exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--infer_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='data/librispeech/mean_std.npz' \
|
||||
--vocab_path='data/librispeech/vocab.txt' \
|
||||
--model_path='checkpoints/libri/params.latest.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/librispeech > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--infer_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='models/librispeech/mean_std.npz' \
|
||||
--vocab_path='models/librispeech/vocab.txt' \
|
||||
--model_path='models/librispeech/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,47 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--test_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='data/librispeech/mean_std.npz' \
|
||||
--vocab_path='data/librispeech/vocab.txt' \
|
||||
--model_path='checkpoints/libri/params.latest.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,56 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/librispeech > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--test_manifest='data/librispeech/manifest.test-clean' \
|
||||
--mean_std_path='models/librispeech/mean_std.npz' \
|
||||
--vocab_path='models/librispeech/vocab.txt' \
|
||||
--model_path='models/librispeech/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --init_model_path
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u train.py \
|
||||
--batch_size=160 \
|
||||
--trainer_count=8 \
|
||||
--num_passes=50 \
|
||||
--num_proc_data=16 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--num_iter_print=100 \
|
||||
--learning_rate=5e-4 \
|
||||
--max_duration=27.0 \
|
||||
--min_duration=0.0 \
|
||||
--test_off=False \
|
||||
--use_sortagrad=True \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--is_local=True \
|
||||
--share_rnn_weights=True \
|
||||
--train_manifest='data/librispeech/manifest.train' \
|
||||
--dev_manifest='data/librispeech/manifest.dev-clean' \
|
||||
--mean_std_path='data/librispeech/mean_std.npz' \
|
||||
--vocab_path='data/librispeech/vocab.txt' \
|
||||
--output_model_dir='./checkpoints/libri' \
|
||||
--augment_conf_path='conf/augmentation.config' \
|
||||
--specgram_type='linear' \
|
||||
--shuffle_method='batch_shuffle_clipped'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# grid-search for hyper-parameters in language model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -u tools/tune.py \
|
||||
--num_batches=-1 \
|
||||
--batch_size=128 \
|
||||
--trainer_count=4 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=12 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--num_alphas=45 \
|
||||
--num_betas=8 \
|
||||
--alpha_from=1.0 \
|
||||
--alpha_to=3.2 \
|
||||
--beta_from=0.1 \
|
||||
--beta_to=0.45 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--tune_manifest='data/librispeech/manifest.dev-clean' \
|
||||
--mean_std_path='data/librispeech/mean_std.npz' \
|
||||
--vocab_path='models/librispeech/vocab.txt' \
|
||||
--model_path='models/librispeech/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in tuning!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,51 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# prepare folder
|
||||
if [ ! -e data/tiny ]; then
|
||||
mkdir data/tiny
|
||||
fi
|
||||
|
||||
|
||||
# download data, generate manifests
|
||||
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
|
||||
--manifest_prefix='data/tiny/manifest' \
|
||||
--target_dir='~/.cache/paddle/dataset/speech/libri' \
|
||||
--full_download='False'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Prepare LibriSpeech failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
head -n 64 data/tiny/manifest.dev-clean > data/tiny/manifest.tiny
|
||||
|
||||
|
||||
# build vocabulary
|
||||
python tools/build_vocab.py \
|
||||
--count_threshold=0 \
|
||||
--vocab_path='data/tiny/vocab.txt' \
|
||||
--manifest_paths='data/tiny/manifest.dev-clean'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Build vocabulary failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# compute mean and stddev for normalizer
|
||||
python tools/compute_mean_std.py \
|
||||
--manifest_path='data/tiny/manifest.tiny' \
|
||||
--num_samples=64 \
|
||||
--specgram_type='linear' \
|
||||
--output_path='data/tiny/mean_std.npz'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Compute mean and stddev failed. Terminated."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "Tiny data preparation done."
|
||||
exit 0
|
@ -0,0 +1,46 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--infer_manifest='data/tiny/manifest.tiny' \
|
||||
--mean_std_path='data/tiny/mean_std.npz' \
|
||||
--vocab_path='data/tiny/vocab.txt' \
|
||||
--model_path='checkpoints/tiny/params.pass-19.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,55 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/librispeech > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# infer
|
||||
CUDA_VISIBLE_DEVICES=0 \
|
||||
python -u infer.py \
|
||||
--num_samples=10 \
|
||||
--trainer_count=1 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--infer_manifest='data/tiny/manifest.test-clean' \
|
||||
--mean_std_path='models/librispeech/mean_std.npz' \
|
||||
--vocab_path='models/librispeech/vocab.txt' \
|
||||
--model_path='models/librispeech/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in inference!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,47 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=16 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--test_manifest='data/tiny/manifest.tiny' \
|
||||
--mean_std_path='data/tiny/mean_std.npz' \
|
||||
--vocab_path='data/tiny/vocab.txt' \
|
||||
--model_path='checkpoints/tiny/params.pass-19.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,56 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# download language model
|
||||
cd models/lm > /dev/null
|
||||
sh download_lm_en.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# download well-trained model
|
||||
cd models/librispeech > /dev/null
|
||||
sh download_model.sh
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
cd - > /dev/null
|
||||
|
||||
|
||||
# evaluate model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u test.py \
|
||||
--batch_size=128 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=8 \
|
||||
--num_proc_data=8 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--alpha=2.5 \
|
||||
--beta=0.3 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--test_manifest='data/tiny/manifest.test-clean' \
|
||||
--mean_std_path='models/librispeech/mean_std.npz' \
|
||||
--vocab_path='models/librispeech/vocab.txt' \
|
||||
--model_path='models/librispeech/params.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--decoding_method='ctc_beam_search' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in evaluation!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# train model
|
||||
# if you wish to resume from an exists model, uncomment --init_model_path
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 \
|
||||
python -u train.py \
|
||||
--batch_size=16 \
|
||||
--trainer_count=4 \
|
||||
--num_passes=20 \
|
||||
--num_proc_data=1 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--num_iter_print=100 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_duration=27.0 \
|
||||
--min_duration=0.0 \
|
||||
--test_off=False \
|
||||
--use_sortagrad=True \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--is_local=True \
|
||||
--share_rnn_weights=True \
|
||||
--train_manifest='data/tiny/manifest.tiny' \
|
||||
--dev_manifest='data/tiny/manifest.tiny' \
|
||||
--mean_std_path='data/tiny/mean_std.npz' \
|
||||
--vocab_path='data/tiny/vocab.txt' \
|
||||
--output_model_dir='./checkpoints/tiny' \
|
||||
--augment_conf_path='conf/augmentation.config' \
|
||||
--specgram_type='linear' \
|
||||
--shuffle_method='batch_shuffle_clipped'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Fail in training!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,41 @@
|
||||
#! /usr/bin/env bash
|
||||
|
||||
cd ../.. > /dev/null
|
||||
|
||||
# grid-search for hyper-parameters in language model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
python -u tools/tune.py \
|
||||
--num_batches=1 \
|
||||
--batch_size=24 \
|
||||
--trainer_count=8 \
|
||||
--beam_size=500 \
|
||||
--num_proc_bsearch=12 \
|
||||
--num_conv_layers=2 \
|
||||
--num_rnn_layers=3 \
|
||||
--rnn_layer_size=2048 \
|
||||
--num_alphas=45 \
|
||||
--num_betas=8 \
|
||||
--alpha_from=1.0 \
|
||||
--alpha_to=3.2 \
|
||||
--beta_from=0.1 \
|
||||
--beta_to=0.45 \
|
||||
--cutoff_prob=1.0 \
|
||||
--cutoff_top_n=40 \
|
||||
--use_gru=False \
|
||||
--use_gpu=True \
|
||||
--share_rnn_weights=True \
|
||||
--tune_manifest='data/tiny/manifest.tiny' \
|
||||
--mean_std_path='data/tiny/mean_std.npz' \
|
||||
--vocab_path='data/tiny/vocab.txt' \
|
||||
--model_path='checkpoints/params.pass-9.tar.gz' \
|
||||
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
|
||||
--error_rate_type='wer' \
|
||||
--specgram_type='linear'
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed in tuning!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
exit 0
|
@ -0,0 +1,135 @@
|
||||
"""Inferer for DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import paddle.v2 as paddle
|
||||
from data_utils.data import DataGenerator
|
||||
from model_utils.model import DeepSpeech2Model
|
||||
from utils.error_rate import wer, cer
|
||||
from utils.utility import add_arguments, print_arguments
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
add_arg = functools.partial(add_arguments, argparser=parser)
|
||||
# yapf: disable
|
||||
add_arg('num_samples', int, 10, "# of samples to infer.")
|
||||
add_arg('trainer_count', int, 8, "# of Trainers (CPUs or GPUs).")
|
||||
add_arg('beam_size', int, 500, "Beam search width.")
|
||||
add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
|
||||
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
|
||||
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
|
||||
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
|
||||
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
|
||||
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
|
||||
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
|
||||
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
|
||||
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
|
||||
add_arg('use_gpu', bool, True, "Use GPU or not.")
|
||||
add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across "
|
||||
"bi-directional RNNs. Not for GRU.")
|
||||
add_arg('infer_manifest', str,
|
||||
'data/librispeech/manifest.dev-clean',
|
||||
"Filepath of manifest to infer.")
|
||||
add_arg('mean_std_path', str,
|
||||
'data/librispeech/mean_std.npz',
|
||||
"Filepath of normalizer's mean & std.")
|
||||
add_arg('vocab_path', str,
|
||||
'data/librispeech/vocab.txt',
|
||||
"Filepath of vocabulary.")
|
||||
add_arg('lang_model_path', str,
|
||||
'models/lm/common_crawl_00.prune01111.trie.klm',
|
||||
"Filepath for language model.")
|
||||
add_arg('model_path', str,
|
||||
'./checkpoints/libri/params.latest.tar.gz',
|
||||
"If None, the training starts from scratch, "
|
||||
"otherwise, it resumes from the pre-trained model.")
|
||||
add_arg('decoding_method', str,
|
||||
'ctc_beam_search',
|
||||
"Decoding method. Options: ctc_beam_search, ctc_greedy",
|
||||
choices = ['ctc_beam_search', 'ctc_greedy'])
|
||||
add_arg('error_rate_type', str,
|
||||
'wer',
|
||||
"Error rate type for evaluation.",
|
||||
choices=['wer', 'cer'])
|
||||
add_arg('specgram_type', str,
|
||||
'linear',
|
||||
"Audio feature type. Options: linear, mfcc.",
|
||||
choices=['linear', 'mfcc'])
|
||||
# yapf: disable
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def infer():
|
||||
"""Inference for DeepSpeech2."""
|
||||
data_generator = DataGenerator(
|
||||
vocab_filepath=args.vocab_path,
|
||||
mean_std_filepath=args.mean_std_path,
|
||||
augmentation_config='{}',
|
||||
specgram_type=args.specgram_type,
|
||||
num_threads=1,
|
||||
keep_transcription_text=True)
|
||||
batch_reader = data_generator.batch_reader_creator(
|
||||
manifest_path=args.infer_manifest,
|
||||
batch_size=args.num_samples,
|
||||
min_batch_size=1,
|
||||
sortagrad=False,
|
||||
shuffle_method=None)
|
||||
infer_data = batch_reader().next()
|
||||
|
||||
ds2_model = DeepSpeech2Model(
|
||||
vocab_size=data_generator.vocab_size,
|
||||
num_conv_layers=args.num_conv_layers,
|
||||
num_rnn_layers=args.num_rnn_layers,
|
||||
rnn_layer_size=args.rnn_layer_size,
|
||||
use_gru=args.use_gru,
|
||||
pretrained_model_path=args.model_path,
|
||||
share_rnn_weights=args.share_rnn_weights)
|
||||
|
||||
# decoders only accept string encoded in utf-8
|
||||
vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]
|
||||
|
||||
if args.decoding_method == "ctc_greedy":
|
||||
ds2_model.logger.info("start inference ...")
|
||||
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
|
||||
feeding_dict=data_generator.feeding)
|
||||
result_transcripts = ds2_model.decode_batch_greedy(
|
||||
probs_split=probs_split,
|
||||
vocab_list=vocab_list)
|
||||
else:
|
||||
ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
|
||||
vocab_list)
|
||||
ds2_model.logger.info("start inference ...")
|
||||
probs_split = ds2_model.infer_batch_probs(infer_data=infer_data,
|
||||
feeding_dict=data_generator.feeding)
|
||||
result_transcripts = ds2_model.decode_batch_beam_search(
|
||||
probs_split=probs_split,
|
||||
beam_alpha=args.alpha,
|
||||
beam_beta=args.beta,
|
||||
beam_size=args.beam_size,
|
||||
cutoff_prob=args.cutoff_prob,
|
||||
cutoff_top_n=args.cutoff_top_n,
|
||||
vocab_list=vocab_list,
|
||||
num_processes=args.num_proc_bsearch)
|
||||
|
||||
error_rate_func = cer if args.error_rate_type == 'cer' else wer
|
||||
target_transcripts = [data[1] for data in infer_data]
|
||||
for target, result in zip(target_transcripts, result_transcripts):
|
||||
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
|
||||
(target, result))
|
||||
print("Current error rate [%s] = %f" %
|
||||
(args.error_rate_type, error_rate_func(target, result)))
|
||||
|
||||
ds2_model.logger.info("finish inference")
|
||||
|
||||
def main():
|
||||
print_arguments(args)
|
||||
paddle.init(use_gpu=args.use_gpu,
|
||||
rnn_use_batch=True,
|
||||
trainer_count=args.trainer_count)
|
||||
infer()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,442 @@
|
||||
"""Contains DeepSpeech2 model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import gzip
|
||||
import copy
|
||||
import inspect
|
||||
from distutils.dir_util import mkpath
|
||||
import paddle.v2 as paddle
|
||||
from decoders.swig_wrapper import Scorer
|
||||
from decoders.swig_wrapper import ctc_greedy_decoder
|
||||
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
||||
from model_utils.network import deep_speech_v2_network
|
||||
|
||||
logging.basicConfig(
|
||||
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
|
||||
|
||||
|
||||
class DeepSpeech2Model(object):
|
||||
"""DeepSpeech2Model class.
|
||||
|
||||
:param vocab_size: Decoding vocabulary size.
|
||||
:type vocab_size: int
|
||||
:param num_conv_layers: Number of stacking convolution layers.
|
||||
:type num_conv_layers: int
|
||||
:param num_rnn_layers: Number of stacking RNN layers.
|
||||
:type num_rnn_layers: int
|
||||
:param rnn_layer_size: RNN layer size (number of RNN cells).
|
||||
:type rnn_layer_size: int
|
||||
:param pretrained_model_path: Pretrained model path. If None, will train
|
||||
from stratch.
|
||||
:type pretrained_model_path: basestring|None
|
||||
:param share_rnn_weights: Whether to share input-hidden weights between
|
||||
forward and backward directional RNNs.Notice that
|
||||
for GRU, weight sharing is not supported.
|
||||
:type share_rnn_weights: bool
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
|
||||
rnn_layer_size, use_gru, pretrained_model_path,
|
||||
share_rnn_weights):
|
||||
self._create_network(vocab_size, num_conv_layers, num_rnn_layers,
|
||||
rnn_layer_size, use_gru, share_rnn_weights)
|
||||
self._create_parameters(pretrained_model_path)
|
||||
self._inferer = None
|
||||
self._loss_inferer = None
|
||||
self._ext_scorer = None
|
||||
self._num_conv_layers = num_conv_layers
|
||||
self.logger = logging.getLogger("")
|
||||
self.logger.setLevel(level=logging.INFO)
|
||||
|
||||
def train(self,
|
||||
train_batch_reader,
|
||||
dev_batch_reader,
|
||||
feeding_dict,
|
||||
learning_rate,
|
||||
gradient_clipping,
|
||||
num_passes,
|
||||
output_model_dir,
|
||||
is_local=True,
|
||||
num_iterations_print=100,
|
||||
test_off=False):
|
||||
"""Train the model.
|
||||
|
||||
:param train_batch_reader: Train data reader.
|
||||
:type train_batch_reader: callable
|
||||
:param dev_batch_reader: Validation data reader.
|
||||
:type dev_batch_reader: callable
|
||||
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||
of the data that reader returns.
|
||||
:type feeding_dict: dict|list
|
||||
:param learning_rate: Learning rate for ADAM optimizer.
|
||||
:type learning_rate: float
|
||||
:param gradient_clipping: Gradient clipping threshold.
|
||||
:type gradient_clipping: float
|
||||
:param num_passes: Number of training epochs.
|
||||
:type num_passes: int
|
||||
:param num_iterations_print: Number of training iterations for printing
|
||||
a training loss.
|
||||
:type rnn_iteratons_print: int
|
||||
:param is_local: Set to False if running with pserver with multi-nodes.
|
||||
:type is_local: bool
|
||||
:param output_model_dir: Directory for saving the model (every pass).
|
||||
:type output_model_dir: basestring
|
||||
:param test_off: Turn off testing.
|
||||
:type test_off: bool
|
||||
"""
|
||||
# prepare model output directory
|
||||
if not os.path.exists(output_model_dir):
|
||||
mkpath(output_model_dir)
|
||||
|
||||
# adapt the feeding dict and reader according to the network
|
||||
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
|
||||
adapted_train_batch_reader = self._adapt_data(train_batch_reader)
|
||||
adapted_dev_batch_reader = self._adapt_data(dev_batch_reader)
|
||||
|
||||
# prepare optimizer and trainer
|
||||
optimizer = paddle.optimizer.Adam(
|
||||
learning_rate=learning_rate,
|
||||
gradient_clipping_threshold=gradient_clipping)
|
||||
trainer = paddle.trainer.SGD(
|
||||
cost=self._loss,
|
||||
parameters=self._parameters,
|
||||
update_equation=optimizer,
|
||||
is_local=is_local)
|
||||
|
||||
# create event handler
|
||||
def event_handler(event):
|
||||
global start_time, cost_sum, cost_counter
|
||||
if isinstance(event, paddle.event.EndIteration):
|
||||
cost_sum += event.cost
|
||||
cost_counter += 1
|
||||
if (event.batch_id + 1) % num_iterations_print == 0:
|
||||
output_model_path = os.path.join(output_model_dir,
|
||||
"params.latest.tar.gz")
|
||||
with gzip.open(output_model_path, 'w') as f:
|
||||
trainer.save_parameter_to_tar(f)
|
||||
print("\nPass: %d, Batch: %d, TrainCost: %f" %
|
||||
(event.pass_id, event.batch_id + 1,
|
||||
cost_sum / cost_counter))
|
||||
cost_sum, cost_counter = 0.0, 0
|
||||
else:
|
||||
sys.stdout.write('.')
|
||||
sys.stdout.flush()
|
||||
if isinstance(event, paddle.event.BeginPass):
|
||||
start_time = time.time()
|
||||
cost_sum, cost_counter = 0.0, 0
|
||||
if isinstance(event, paddle.event.EndPass):
|
||||
if test_off:
|
||||
print("\n------- Time: %d sec, Pass: %d" %
|
||||
(time.time() - start_time, event.pass_id))
|
||||
else:
|
||||
result = trainer.test(
|
||||
reader=adapted_dev_batch_reader,
|
||||
feeding=adapted_feeding_dict)
|
||||
print(
|
||||
"\n------- Time: %d sec, Pass: %d, "
|
||||
"ValidationCost: %s" %
|
||||
(time.time() - start_time, event.pass_id, result.cost))
|
||||
output_model_path = os.path.join(
|
||||
output_model_dir, "params.pass-%d.tar.gz" % event.pass_id)
|
||||
with gzip.open(output_model_path, 'w') as f:
|
||||
trainer.save_parameter_to_tar(f)
|
||||
|
||||
# run train
|
||||
trainer.train(
|
||||
reader=adapted_train_batch_reader,
|
||||
event_handler=event_handler,
|
||||
num_passes=num_passes,
|
||||
feeding=adapted_feeding_dict)
|
||||
|
||||
# TODO(@pkuyym) merge this function into infer_batch
|
||||
def infer_loss_batch(self, infer_data):
|
||||
"""Model inference. Infer the ctc loss for a batch of speech
|
||||
utterances.
|
||||
|
||||
:param infer_data: List of utterances to infer, with each utterance a
|
||||
tuple of audio features and transcription text (empty
|
||||
string).
|
||||
:type infer_data: list
|
||||
:return: List of ctc loss.
|
||||
:rtype: List of float
|
||||
"""
|
||||
# define inferer
|
||||
if self._loss_inferer == None:
|
||||
self._loss_inferer = paddle.inference.Inference(
|
||||
output_layer=self._loss, parameters=self._parameters)
|
||||
# run inference
|
||||
return self._loss_inferer.infer(input=infer_data)
|
||||
|
||||
def infer_batch_probs(self, infer_data, feeding_dict):
|
||||
"""Infer the prob matrices for a batch of speech utterances.
|
||||
|
||||
:param infer_data: List of utterances to infer, with each utterance
|
||||
consisting of a tuple of audio features and
|
||||
transcription text (empty string).
|
||||
:type infer_data: list
|
||||
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||
of the data that reader returns.
|
||||
:type feeding_dict: dict|list
|
||||
:return: List of 2-D probability matrix, and each consists of prob
|
||||
vectors for one speech utterancce.
|
||||
:rtype: List of matrix
|
||||
"""
|
||||
# define inferer
|
||||
if self._inferer == None:
|
||||
self._inferer = paddle.inference.Inference(
|
||||
output_layer=self._log_probs, parameters=self._parameters)
|
||||
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
|
||||
adapted_infer_data = self._adapt_data(infer_data)
|
||||
# run inference
|
||||
infer_results = self._inferer.infer(
|
||||
input=adapted_infer_data, feeding=adapted_feeding_dict)
|
||||
start_pos = [0] * (len(adapted_infer_data) + 1)
|
||||
for i in xrange(len(adapted_infer_data)):
|
||||
start_pos[i + 1] = start_pos[i] + adapted_infer_data[i][3][0]
|
||||
probs_split = [
|
||||
infer_results[start_pos[i]:start_pos[i + 1]]
|
||||
for i in xrange(0, len(adapted_infer_data))
|
||||
]
|
||||
return probs_split
|
||||
|
||||
def decode_batch_greedy(self, probs_split, vocab_list):
|
||||
"""Decode by best path for a batch of probs matrix input.
|
||||
|
||||
:param probs_split: List of 2-D probability matrix, and each consists
|
||||
of prob vectors for one speech utterancce.
|
||||
:param probs_split: List of matrix
|
||||
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||
:type vocab_list: list
|
||||
:return: List of transcription texts.
|
||||
:rtype: List of basestring
|
||||
"""
|
||||
results = []
|
||||
for i, probs in enumerate(probs_split):
|
||||
output_transcription = ctc_greedy_decoder(
|
||||
probs_seq=probs, vocabulary=vocab_list)
|
||||
results.append(output_transcription)
|
||||
return results
|
||||
|
||||
def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
|
||||
vocab_list):
|
||||
"""Initialize the external scorer.
|
||||
|
||||
:param beam_alpha: Parameter associated with language model.
|
||||
:type beam_alpha: float
|
||||
:param beam_beta: Parameter associated with word count.
|
||||
:type beam_beta: float
|
||||
:param language_model_path: Filepath for language model. If it is
|
||||
empty, the external scorer will be set to
|
||||
None, and the decoding method will be pure
|
||||
beam search without scorer.
|
||||
:type language_model_path: basestring|None
|
||||
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||
:type vocab_list: list
|
||||
"""
|
||||
if language_model_path != '':
|
||||
self.logger.info("begin to initialize the external scorer "
|
||||
"for decoding")
|
||||
self._ext_scorer = Scorer(beam_alpha, beam_beta,
|
||||
language_model_path, vocab_list)
|
||||
lm_char_based = self._ext_scorer.is_character_based()
|
||||
lm_max_order = self._ext_scorer.get_max_order()
|
||||
lm_dict_size = self._ext_scorer.get_dict_size()
|
||||
self.logger.info("language model: "
|
||||
"is_character_based = %d," % lm_char_based +
|
||||
" max_order = %d," % lm_max_order +
|
||||
" dict_size = %d" % lm_dict_size)
|
||||
self.logger.info("end initializing scorer")
|
||||
else:
|
||||
self._ext_scorer = None
|
||||
self.logger.info("no language model provided, "
|
||||
"decoding by pure beam search without scorer.")
|
||||
|
||||
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
|
||||
beam_size, cutoff_prob, cutoff_top_n,
|
||||
vocab_list, num_processes):
|
||||
"""Decode by beam search for a batch of probs matrix input.
|
||||
|
||||
:param probs_split: List of 2-D probability matrix, and each consists
|
||||
of prob vectors for one speech utterancce.
|
||||
:param probs_split: List of matrix
|
||||
:param beam_alpha: Parameter associated with language model.
|
||||
:type beam_alpha: float
|
||||
:param beam_beta: Parameter associated with word count.
|
||||
:type beam_beta: float
|
||||
:param beam_size: Width for Beam search.
|
||||
:type beam_size: int
|
||||
:param cutoff_prob: Cutoff probability in pruning,
|
||||
default 1.0, no pruning.
|
||||
:type cutoff_prob: float
|
||||
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
|
||||
characters with highest probs in vocabulary will be
|
||||
used in beam search, default 40.
|
||||
:type cutoff_top_n: int
|
||||
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||
:type vocab_list: list
|
||||
:param num_processes: Number of processes (CPU) for decoder.
|
||||
:type num_processes: int
|
||||
:return: List of transcription texts.
|
||||
:rtype: List of basestring
|
||||
"""
|
||||
if self._ext_scorer != None:
|
||||
self._ext_scorer.reset_params(beam_alpha, beam_beta)
|
||||
# beam search decode
|
||||
num_processes = min(num_processes, len(probs_split))
|
||||
beam_search_results = ctc_beam_search_decoder_batch(
|
||||
probs_split=probs_split,
|
||||
vocabulary=vocab_list,
|
||||
beam_size=beam_size,
|
||||
num_processes=num_processes,
|
||||
ext_scoring_func=self._ext_scorer,
|
||||
cutoff_prob=cutoff_prob,
|
||||
cutoff_top_n=cutoff_top_n)
|
||||
|
||||
results = [result[0][1] for result in beam_search_results]
|
||||
return results
|
||||
|
||||
def _adapt_feeding_dict(self, feeding_dict):
|
||||
"""Adapt feeding dict according to network struct.
|
||||
|
||||
To remove impacts from padding part, we add scale_sub_region layer and
|
||||
sub_seq layer. For sub_seq layer, 'sequence_offset' and
|
||||
'sequence_length' fields are appended. For each scale_sub_region layer
|
||||
'convN_index_range' field is appended.
|
||||
|
||||
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||
of the data that reader returns.
|
||||
:type feeding_dict: dict|list
|
||||
:return: Adapted feeding dict.
|
||||
:rtype: dict|list
|
||||
"""
|
||||
adapted_feeding_dict = copy.deepcopy(feeding_dict)
|
||||
if isinstance(feeding_dict, dict):
|
||||
adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict)
|
||||
adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict)
|
||||
for i in xrange(self._num_conv_layers):
|
||||
adapted_feeding_dict["conv%d_index_range" %i] = \
|
||||
len(adapted_feeding_dict)
|
||||
elif isinstance(feeding_dict, list):
|
||||
adapted_feeding_dict.append("sequence_offset")
|
||||
adapted_feeding_dict.append("sequence_length")
|
||||
for i in xrange(self._num_conv_layers):
|
||||
adapted_feeding_dict.append("conv%d_index_range" % i)
|
||||
else:
|
||||
raise ValueError("Type of feeding_dict is %s, not supported." %
|
||||
type(feeding_dict))
|
||||
|
||||
return adapted_feeding_dict
|
||||
|
||||
def _adapt_data(self, data):
    """Adapt data according to network struct.

    For each convolution layer in the conv_group, to remove impacts from
    padding data, we multiply zero to the padding part of the outputs of
    each batch normalization layer (a scale_sub_region layer is added after
    each batch normalization layer to reset the padding data).
    For rnn layers, the padding part is truncated before the output data is
    fed into the first rnn layer, using a sub_seq layer.

    :param data: Data from data_provider.
    :type data: list|function
    :return: Adapted data.
    :rtype: list|function
    """

    def _adapt_one(sample):
        # A sample is (padded_audio, text) or (padded_audio, text, len).
        if not (2 <= len(sample) <= 3):
            raise ValueError("Size of instance should be 2 or 3.")
        spectrogram = sample[0]
        transcript = sample[1]
        # Without an explicit length field the whole width is valid.
        if len(sample) == 2:
            valid_len = spectrogram.shape[1]
        else:
            valid_len = sample[2]
        adapted = [spectrogram, transcript]
        # Hard-coded strides mirroring the network definition:
        # conv0 uses stride (3, 2); conv1 .. convN use stride (1, 2).
        conv0_h = (spectrogram.shape[0] - 1) // 2 + 1
        conv0_w = (spectrogram.shape[1] - 1) // 3 + 1
        valid_w = (valid_len - 1) // 3 + 1
        adapted.append([0])  # sequence offset, always 0
        adapted.append([valid_w])  # valid sequence length
        # Index ranges for channel, height and width; consumed by the
        # scale_sub_region layers (see that layer's docs for details).
        adapted.append([1, 32, 1, conv0_h, valid_w + 1, conv0_w])
        height = conv0_h
        for _ in xrange(self._num_conv_layers - 1):
            # Height halves per stacked conv layer; width is unchanged.
            height = (height - 1) // 2 + 1
            adapted.append([1, 32, 1, height, valid_w + 1, conv0_w])
        return adapted

    if isinstance(data, list):
        return [_adapt_one(sample) for sample in data]
    elif inspect.isgeneratorfunction(data):

        def adapted_reader():
            for batch in data():
                yield [_adapt_one(sample) for sample in batch]

        return adapted_reader
    else:
        raise ValueError("Type of data is %s, not supported." % type(data))
||||
def _create_parameters(self, model_path=None):
    """Load model parameters from a gzipped tar file, or create fresh
    parameters from the loss topology when no path is given.
    """
    if model_path is not None:
        self._parameters = paddle.parameters.Parameters.from_tar(
            gzip.open(model_path))
    else:
        self._parameters = paddle.parameters.create(self._loss)
||||
def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
                    rnn_layer_size, use_gru, share_rnn_weights):
    """Create data layers and model network.

    :param vocab_size: Dictionary size for tokenized transcription.
    :type vocab_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_layer_size: RNN layer size (number of RNN cells).
    :type rnn_layer_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
    :type share_rnn_weights: bool
    """
    # paddle.data_type.dense_array is used for variable batch input.
    # The size 161 * 161 is only a placeholder value and the real shape
    # of the input batch data will be induced during training.
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        type=paddle.data_type.dense_array(161 * 161))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(vocab_size))
    seq_offset_data = paddle.layer.data(
        name='sequence_offset',
        type=paddle.data_type.integer_value_sequence(1))
    seq_len_data = paddle.layer.data(
        name='sequence_length',
        type=paddle.data_type.integer_value_sequence(1))
    index_range_datas = []
    # One index-range data layer is needed per *convolution* layer (each
    # conv layer owns a scale_sub_region layer), so iterate over
    # num_conv_layers. Iterating over num_rnn_layers (as before) created a
    # count mismatch with _adapt_feeding_dict/_adapt_data, which supply
    # exactly num_conv_layers 'conv%d_index_range' fields.
    for i in xrange(num_conv_layers):
        index_range_datas.append(
            paddle.layer.data(
                name='conv%d_index_range' % i,
                type=paddle.data_type.dense_vector(6)))

    self._log_probs, self._loss = deep_speech_v2_network(
        audio_data=audio_data,
        text_data=text_data,
        seq_offset_data=seq_offset_data,
        seq_len_data=seq_len_data,
        index_range_datas=index_range_datas,
        dict_size=vocab_size,
        num_conv_layers=num_conv_layers,
        num_rnn_layers=num_rnn_layers,
        rnn_size=rnn_layer_size,
        use_gru=use_gru,
        share_rnn_weights=share_rnn_weights)
@ -0,0 +1,302 @@
|
||||
"""Contains DeepSpeech2 layers and networks."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.v2 as paddle
|
||||
|
||||
|
||||
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
                  padding, act, index_range_data):
    """Convolution layer with batch normalization.

    :param input: Input layer.
    :type input: LayerOutput
    :param filter_size: The x dimension of a filter kernel. Or input a tuple
                        for two image dimension.
    :type filter_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: The x dimension of the stride. Or input a tuple for two
                   image dimension.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimension.
    :type padding: int|tuple|list
    :param act: Activation type.
    :type act: BaseActivation
    :param index_range_data: Index range to indicate sub region.
    :type index_range_data: LayerOutput
    :return: Batch norm layer after convolution layer.
    :rtype: LayerOutput
    """
    # The convolution itself is linear; the activation is applied by the
    # batch_norm layer that follows.
    conv_layer = paddle.layer.img_conv(
        input=input,
        filter_size=filter_size,
        num_channels=num_channels_in,
        num_filters=num_channels_out,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear(),
        bias_attr=False)
    batch_norm = paddle.layer.batch_norm(input=conv_layer, act=act)
    # Reset the padding part of the feature map to 0 so padded time steps
    # do not contribute to subsequent layers.
    scale_sub_region = paddle.layer.scale_sub_region(
        batch_norm, index_range_data, value=0.0)
    return scale_sub_region
||||
|
||||
|
||||
def bidirectional_simple_rnn_bn_layer(name, input, size, act, share_weights):
    """Bidirectonal simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type.
    :type act: BaseActivation
    :param share_weights: Whether to share input-hidden weights between
                          forward and backward directional RNNs.
    :type share_weights: bool
    :return: Bidirectional simple rnn layer.
    :rtype: LayerOutput
    """
    # Common arguments for the input-state projection(s).
    proj_kwargs = dict(
        input=input, size=size, act=paddle.activation.Linear(),
        bias_attr=False)
    if share_weights:
        # Single projection shared by both directions; batch norm is only
        # performed on this input-state projection.
        shared_proj = paddle.layer.fc(**proj_kwargs)
        shared_bn = paddle.layer.batch_norm(
            input=shared_proj, act=paddle.activation.Linear())
        fwd_input = shared_bn
        bwd_input = shared_bn
    else:
        # Independent projections per direction (created before the batch
        # norm layers, matching the original construction order), each
        # followed by its own batch norm.
        fwd_proj = paddle.layer.fc(**proj_kwargs)
        bwd_proj = paddle.layer.fc(**proj_kwargs)
        fwd_input = paddle.layer.batch_norm(
            input=fwd_proj, act=paddle.activation.Linear())
        bwd_input = paddle.layer.batch_norm(
            input=bwd_proj, act=paddle.activation.Linear())
    # Run the simple RNN forward and backward in time, then concatenate.
    rnn_fwd = paddle.layer.recurrent(input=fwd_input, act=act, reverse=False)
    rnn_bwd = paddle.layer.recurrent(input=bwd_input, act=act, reverse=True)
    return paddle.layer.concat(input=[rnn_fwd, rnn_bwd])
||||
|
||||
|
||||
def bidirectional_gru_bn_layer(name, input, size, act):
    """Bidirectonal gru layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type.
    :type act: BaseActivation
    :return: Bidirectional gru layer.
    :rtype: LayerOutput
    """
    # GRU needs a projection of 3 * size (update/reset/candidate gates).
    proj_kwargs = dict(
        input=input, size=size * 3, act=paddle.activation.Linear(),
        bias_attr=False)
    fwd_proj = paddle.layer.fc(**proj_kwargs)
    bwd_proj = paddle.layer.fc(**proj_kwargs)
    # Batch norm is only performed on the input-related projections.
    fwd_bn = paddle.layer.batch_norm(
        input=fwd_proj, act=paddle.activation.Linear())
    bwd_bn = paddle.layer.batch_norm(
        input=bwd_proj, act=paddle.activation.Linear())
    # Run the GRU memories forward and backward in time, then concatenate.
    gru_fwd = paddle.layer.grumemory(input=fwd_bn, act=act, reverse=False)
    gru_bwd = paddle.layer.grumemory(input=bwd_bn, act=act, reverse=True)
    return paddle.layer.concat(input=[gru_fwd, gru_bwd])
||||
|
||||
|
||||
def conv_group(input, num_stacks, index_range_datas):
    """Convolution group with stacked convolution layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    :param index_range_datas: Index ranges for each convolution layer.
    :type index_range_datas: tuple|list
    :return: Tuple of (output layer, number of output channels,
             output feature-map height).
    :rtype: tuple
    """
    # First conv layer: single input channel, stride (3, 2).
    output = conv_bn_layer(
        input=input,
        filter_size=(11, 41),
        num_channels_in=1,
        num_channels_out=32,
        stride=(3, 2),
        padding=(5, 20),
        act=paddle.activation.BRelu(),
        index_range_data=index_range_datas[0])
    # Remaining conv layers: 32 channels in/out, stride (1, 2).
    for layer_id in xrange(1, num_stacks):
        output = conv_bn_layer(
            input=output,
            filter_size=(11, 21),
            num_channels_in=32,
            num_channels_out=32,
            stride=(1, 2),
            padding=(5, 10),
            act=paddle.activation.BRelu(),
            index_range_data=index_range_datas[layer_id])
    # Each stacked layer halves the height of the 161-tall spectrogram.
    output_height = 160 // pow(2, num_stacks) + 1
    return output, 32, output_height
||||
|
||||
|
||||
def rnn_group(input, size, num_stacks, use_gru, share_rnn_weights):
    """RNN group with stacked bidirectional simple RNN or GRU layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: Output layer of the RNN group.
    :rtype: LayerOutput
    """
    output = input
    for layer_id in xrange(num_stacks):
        if not use_gru:
            output = bidirectional_simple_rnn_bn_layer(
                name=str(layer_id),
                input=output,
                size=size,
                act=paddle.activation.BRelu(),
                share_weights=share_rnn_weights)
        else:
            # BRelu does not support hppl, need to add later. Use Relu
            # instead for the GRU stack.
            output = bidirectional_gru_bn_layer(
                name=str(layer_id),
                input=output,
                size=size,
                act=paddle.activation.Relu())
    return output
||||
|
||||
|
||||
def deep_speech_v2_network(audio_data,
                           text_data,
                           seq_offset_data,
                           seq_len_data,
                           index_range_datas,
                           dict_size,
                           num_conv_layers=2,
                           num_rnn_layers=3,
                           rnn_size=256,
                           use_gru=False,
                           share_rnn_weights=True):
    """The DeepSpeech2 network structure.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: LayerOutput
    :param text_data: Transcription text data layer.
    :type text_data: LayerOutput
    :param seq_offset_data: Sequence offset data layer.
    :type seq_offset_data: LayerOutput
    :param seq_len_data: Valid sequence length data layer.
    :type seq_len_data: LayerOutput
    :param index_range_datas: Index ranges data layers.
    :type index_range_datas: tuple|list
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (number of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """
    # convolution group
    conv_group_output, conv_group_num_channels, conv_group_height = conv_group(
        input=audio_data,
        num_stacks=num_conv_layers,
        index_range_datas=index_range_datas)
    # convert data from convolution feature map to sequence of vectors
    conv2seq = paddle.layer.block_expand(
        input=conv_group_output,
        num_channels=conv_group_num_channels,
        stride_x=1,
        stride_y=1,
        block_x=1,
        block_y=conv_group_height)
    # remove the padding part of each sequence before the RNN stack
    remove_padding_data = paddle.layer.sub_seq(
        input=conv2seq,
        offsets=seq_offset_data,
        sizes=seq_len_data,
        act=paddle.activation.Linear(),
        bias_attr=False)
    # rnn group
    rnn_group_output = rnn_group(
        input=remove_padding_data,
        size=rnn_size,
        num_stacks=num_rnn_layers,
        use_gru=use_gru,
        share_rnn_weights=share_rnn_weights)
    # output projection: dict_size + 1 classes (extra one for the CTC blank)
    fc = paddle.layer.fc(
        input=rnn_group_output,
        size=dict_size + 1,
        act=paddle.activation.Linear(),
        bias_attr=True)
    # probability distribution with softmax
    log_probs = paddle.layer.mixed(
        input=paddle.layer.identity_projection(input=fc),
        act=paddle.activation.Softmax())
    # ctc cost; index dict_size is reserved for the blank label
    ctc_loss = paddle.layer.warp_ctc(
        input=fc,
        label=text_data,
        size=dict_size + 1,
        blank=dict_size,
        norm_by_times=True)
    return log_probs, ctc_loss
@ -0,0 +1,19 @@
|
||||
#! /usr/bin/env bash
# Download and unpack the pre-trained Aishell (Mandarin) model.

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model.tar.gz'
MD5=0ee83aa15fba421e5de8fc66c8feb350
TARGET=./aishell_model.tar.gz


echo "Download Aishell model ..."
# Test the download's exit status directly instead of inspecting $?
# afterwards, which breaks silently if a command is inserted in between.
if ! download "$URL" "$MD5" "$TARGET"; then
    echo "Fail to download Aishell model!"
    exit 1
fi
tar -zxvf "$TARGET"


exit 0
@ -0,0 +1,19 @@
|
||||
#! /usr/bin/env bash
# Download and unpack the pre-trained BaiduEn8k (English) model.

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/demo_models/baidu_en8k_model.tar.gz'
MD5=5fe7639e720d51b3c3bdf7a1470c6272
TARGET=./baidu_en8k_model.tar.gz


echo "Download BaiduEn8k model ..."
# Test the download's exit status directly instead of inspecting $?
# afterwards, which breaks silently if a command is inserted in between.
if ! download "$URL" "$MD5" "$TARGET"; then
    echo "Fail to download BaiduEn8k model!"
    exit 1
fi
tar -zxvf "$TARGET"


exit 0
@ -0,0 +1,19 @@
|
||||
#! /usr/bin/env bash
# Download and unpack the pre-trained LibriSpeech (English) model.

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_model.tar.gz'
MD5=1f72d0c5591f453362f0caa09dd57618
TARGET=./librispeech_model.tar.gz


echo "Download LibriSpeech model ..."
# Test the download's exit status directly instead of inspecting $?
# afterwards, which breaks silently if a command is inserted in between.
if ! download "$URL" "$MD5" "$TARGET"; then
    echo "Fail to download LibriSpeech model!"
    exit 1
fi
tar -zxvf "$TARGET"


exit 0
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue