commit 77a456b143
@@ -0,0 +1,29 @@
# This file is used by clang-format to autoformat paddle source code
#
# clang-format is part of the LLVM toolchain.
# LLVM and clang need to be installed to format source code with it.
#
# The basic usage is,
#     clang-format -i -style=file PATH/TO/SOURCE/CODE
#
# -style=file implicitly uses the ".clang-format" file located in one of
# the parent directories.
# -i means in-place change.
#
# The documentation of clang-format is at
#     http://clang.llvm.org/docs/ClangFormat.html
#     http://clang.llvm.org/docs/ClangFormatStyleOptions.html
---
Language: Cpp
BasedOnStyle: Google
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
MaxEmptyLinesToKeep: 2
AccessModifierOffset: -2 # private/protected/public have no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
BinPackArguments: false
...
@@ -0,0 +1,15 @@
#!/usr/bin/env bash
set -e

readonly VERSION="3.9"

version=$(clang-format -version)

if ! [[ $version == *"$VERSION"* ]]; then
    echo "clang-format version check failed."
    echo "a version containing '$VERSION' is needed, but got '$version'"
    echo "you can install the right version and make a soft link to it in a directory on your '\$PATH'"
    exit 1
fi

clang-format "$@"
@@ -0,0 +1,2 @@
.DS_Store
*.pyc
@@ -0,0 +1,35 @@
- repo: https://github.com/pre-commit/mirrors-yapf.git
  sha: v0.16.0
  hooks:
    - id: yapf
      files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
  sha: a11d9314b22d8f8c7556443875b731ef05965464
  hooks:
    - id: check-merge-conflict
    - id: check-symlinks
    - id: detect-private-key
      files: (?!.*paddle)^.*$
    - id: end-of-file-fixer
      files: \.md$
    - id: trailing-whitespace
      files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
  sha: v1.0.1
  hooks:
    - id: forbid-crlf
      files: \.md$
    - id: remove-crlf
      files: \.md$
    - id: forbid-tabs
      files: \.md$
    - id: remove-tabs
      files: \.md$
- repo: local
  hooks:
    - id: clang-format
      name: clang-format
      description: Format files with ClangFormat
      entry: bash .clang_format.hook -i
      language: system
      files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
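# The clang-format entry above delegates to .clang_format.hook, which checks
# that the installed clang-format version contains "3.9" before formatting.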
@@ -0,0 +1,3 @@
[style]
based_on_style = pep8
column_limit = 80
@@ -0,0 +1,34 @@
language: cpp
cache: ccache
sudo: required
dist: trusty
services:
  - docker
os:
  - linux
env:
  - JOB=PRE_COMMIT

addons:
  apt:
    packages:
      - git
      - python
      - python-pip
      - python2.7-dev

before_install:
  - sudo pip install -U virtualenv pre-commit pip
  - docker pull paddlepaddle/paddle:latest

script:
  - exit_code=0
  - .travis/precommit.sh || exit_code=$(( exit_code | $? ))
  - docker run -i --rm -v "$PWD:/py_unittest" paddlepaddle/paddle:latest /bin/bash -c
    'cd /py_unittest; sh .travis/unittest.sh' || exit_code=$(( exit_code | $? ))
  - exit $exit_code

notifications:
  email:
    on_success: change
    on_failure: always
@@ -0,0 +1,21 @@
#!/bin/bash
function abort(){
    echo "Your commit does not fit the PaddlePaddle code style" 1>&2
    echo "Please use the pre-commit scripts to auto-format your code" 1>&2
    exit 1
}

trap 'abort' 0
set -e
cd `dirname $0`
cd ..
export PATH=/usr/bin:$PATH
pre-commit install

if ! pre-commit run -a ; then
    ls -lh
    git diff --exit-code
    exit 1
fi

trap : 0
@@ -0,0 +1,29 @@
#!/bin/bash

abort(){
    echo "Running unittest failed" 1>&2
    echo "Please check your code" 1>&2
    exit 1
}

unittest(){
    cd "$1" > /dev/null
    if [ -f "setup.sh" ]; then
        sh setup.sh
        export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
    fi
    if [ $? != 0 ]; then
        exit 1
    fi
    find . -name 'tests' -type d -print0 | \
        xargs -0 -I{} -n1 bash -c \
        'python -m unittest discover -v -s {}'
    cd - > /dev/null
}

trap 'abort' 0
set -e

unittest .

trap : 0
@@ -0,0 +1,17 @@
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


this_dir = os.path.dirname(__file__)
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
@@ -0,0 +1,29 @@
#! /usr/bin/env bash

TRAIN_MANIFEST="cloud/cloud_manifests/cloud.manifest.train"
DEV_MANIFEST="cloud/cloud_manifests/cloud.manifest.dev"
CLOUD_MODEL_DIR="./checkpoints"
BATCH_SIZE=512
NUM_GPU=8
NUM_NODE=1
IS_LOCAL="True"

JOB_NAME=deepspeech-`date +%Y%m%d%H%M%S`
DS2_PATH=${PWD%/*}
cp -f pcloud_train.sh ${DS2_PATH}

paddlecloud submit \
    -image bootstrapper:5000/paddlepaddle/pcloud_ds2:latest \
    -jobname ${JOB_NAME} \
    -cpu ${NUM_GPU} \
    -gpu ${NUM_GPU} \
    -memory 64Gi \
    -parallelism ${NUM_NODE} \
    -pscpu 1 \
    -pservers 1 \
    -psmemory 64Gi \
    -passes 1 \
    -entry "sh pcloud_train.sh ${TRAIN_MANIFEST} ${DEV_MANIFEST} ${CLOUD_MODEL_DIR} ${NUM_GPU} ${BATCH_SIZE} ${IS_LOCAL}" \
    ${DS2_PATH}

rm ${DS2_PATH}/pcloud_train.sh
@@ -0,0 +1,46 @@
#! /usr/bin/env bash

TRAIN_MANIFEST=$1
DEV_MANIFEST=$2
MODEL_PATH=$3
NUM_GPU=$4
BATCH_SIZE=$5
IS_LOCAL=$6

python ./cloud/split_data.py \
    --in_manifest_path=${TRAIN_MANIFEST} \
    --out_manifest_path='/local.manifest.train'

python ./cloud/split_data.py \
    --in_manifest_path=${DEV_MANIFEST} \
    --out_manifest_path='/local.manifest.dev'

mkdir ./logs

python -u train.py \
    --batch_size=${BATCH_SIZE} \
    --trainer_count=${NUM_GPU} \
    --num_passes=200 \
    --num_proc_data=${NUM_GPU} \
    --num_conv_layers=2 \
    --num_rnn_layers=3 \
    --rnn_layer_size=2048 \
    --num_iter_print=100 \
    --learning_rate=5e-4 \
    --max_duration=27.0 \
    --min_duration=0.0 \
    --use_sortagrad=True \
    --use_gru=False \
    --use_gpu=True \
    --is_local=${IS_LOCAL} \
    --share_rnn_weights=True \
    --train_manifest='/local.manifest.train' \
    --dev_manifest='/local.manifest.dev' \
    --mean_std_path='data/librispeech/mean_std.npz' \
    --vocab_path='data/librispeech/vocab.txt' \
    --output_model_dir=${MODEL_PATH} \
    --augment_conf_path='conf/augmentation.config' \
    --specgram_type='linear' \
    --shuffle_method='batch_shuffle_clipped' \
    2>&1 | tee ./logs/train.log
@@ -0,0 +1,22 @@
#! /usr/bin/env bash

mkdir cloud_manifests

IN_MANIFESTS="../data/librispeech/manifest.train ../data/librispeech/manifest.dev-clean ../data/librispeech/manifest.test-clean"
OUT_MANIFESTS="cloud_manifests/cloud.manifest.train cloud_manifests/cloud.manifest.dev cloud_manifests/cloud.manifest.test"
CLOUD_DATA_DIR="/pfs/dlnel/home/USERNAME/deepspeech2/data/librispeech"
NUM_SHARDS=50

python upload_data.py \
    --in_manifest_paths ${IN_MANIFESTS} \
    --out_manifest_paths ${OUT_MANIFESTS} \
    --cloud_data_dir ${CLOUD_DATA_DIR} \
    --num_shards ${NUM_SHARDS}

if [ $? -ne 0 ]
then
    echo "Upload Data Failed!"
    exit 1
fi

echo "All Done."
@@ -0,0 +1,41 @@
"""This tool is used for splitting data into each node of
paddlecloud. This script should be called in paddlecloud.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import json
import argparse

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--in_manifest_path",
    type=str,
    required=True,
    help="Input manifest path for all nodes.")
parser.add_argument(
    "--out_manifest_path",
    type=str,
    required=True,
    help="Output manifest file path for current node.")
args = parser.parse_args()


def split_data(in_manifest_path, out_manifest_path):
    with open("/trainer_id", "r") as f:
        trainer_id = int(f.readline()[:-1])
    with open("/trainer_count", "r") as f:
        trainer_count = int(f.readline()[:-1])

    out_manifest = []
    for index, json_line in enumerate(open(in_manifest_path, 'r')):
        if (index % trainer_count) == trainer_id:
            out_manifest.append("%s\n" % json_line.strip())
    with open(out_manifest_path, 'w') as f:
        f.writelines(out_manifest)


if __name__ == '__main__':
    split_data(args.in_manifest_path, args.out_manifest_path)
@@ -0,0 +1,129 @@
"""This script is for uploading data for DeepSpeech2 training on paddlecloud.

Steps:
1. Read original manifests and extract local sound files.
2. Tar all local sound files into multiple tar files and upload them.
3. Modify original manifests with updated paths in the cloud filesystem.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tarfile
import sys
import argparse
import shutil
from subprocess import call
import _init_paths
from data_utils.utils import read_manifest

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--in_manifest_paths",
    default=[
        "../datasets/manifest.train", "../datasets/manifest.dev",
        "../datasets/manifest.test"
    ],
    type=str,
    nargs='+',
    help="Local filepaths of input manifests to load, pack and upload. "
    "(default: %(default)s)")
parser.add_argument(
    "--out_manifest_paths",
    default=[
        "./cloud.manifest.train", "./cloud.manifest.dev",
        "./cloud.manifest.test"
    ],
    type=str,
    nargs='+',
    help="Local filepaths of modified manifests to write to. "
    "(default: %(default)s)")
parser.add_argument(
    "--cloud_data_dir",
    required=True,
    type=str,
    help="Destination directory on paddlecloud to upload data to.")
parser.add_argument(
    "--num_shards",
    default=10,
    type=int,
    help="Number of parts to split data to. (default: %(default)s)")
parser.add_argument(
    "--local_tmp_dir",
    default="./tmp/",
    type=str,
    help="Local directory for storing temporary data. (default: %(default)s)")
args = parser.parse_args()


def upload_data(in_manifest_path_list, out_manifest_path_list, local_tmp_dir,
                upload_tar_dir, num_shards):
    """Extract and pack sound files listed in the manifest files into multiple
    tar files and upload them to paddlecloud. Besides, generate new manifest
    files with updated paths in paddlecloud.
    """
    # compute total audio number
    total_line = 0
    for manifest_path in in_manifest_path_list:
        with open(manifest_path, 'r') as f:
            total_line += len(f.readlines())
    line_per_tar = (total_line // num_shards) + 1

    # pack and upload shard by shard
    line_count, tar_file = 0, None
    for manifest_path, out_manifest_path in zip(in_manifest_path_list,
                                                out_manifest_path_list):
        manifest = read_manifest(manifest_path)
        out_manifest = []
        for json_data in manifest:
            sound_filepath = json_data['audio_filepath']
            sound_filename = os.path.basename(sound_filepath)
            if line_count % line_per_tar == 0:
                if tar_file is not None:
                    tar_file.close()
                    pcloud_cp(tar_path, upload_tar_dir)
                    os.remove(tar_path)
                tar_name = 'part-%s-of-%s.tar' % (
                    str(line_count // line_per_tar).zfill(5),
                    str(num_shards).zfill(5))
                tar_path = os.path.join(local_tmp_dir, tar_name)
                tar_file = tarfile.open(tar_path, 'w')
            tar_file.add(sound_filepath, arcname=sound_filename)
            line_count += 1
            json_data['audio_filepath'] = "tar:%s#%s" % (
                os.path.join(upload_tar_dir, tar_name), sound_filename)
            out_manifest.append("%s\n" % json.dumps(json_data))
        with open(out_manifest_path, 'w') as f:
            f.writelines(out_manifest)
        pcloud_cp(out_manifest_path, upload_tar_dir)
    tar_file.close()
    pcloud_cp(tar_path, upload_tar_dir)
    os.remove(tar_path)


def pcloud_mkdir(dir):
    """Make a directory in the PaddleCloud filesystem."""
    if call(['paddlecloud', 'mkdir', dir]) != 0:
        raise IOError("PaddleCloud mkdir failed: %s." % dir)


def pcloud_cp(src, dst):
    """Copy src from the local filesystem to dst in the PaddleCloud
    filesystem, or download src from the PaddleCloud filesystem to dst in the
    local filesystem.
    """
    if call(['paddlecloud', 'cp', src, dst]) != 0:
        raise IOError("PaddleCloud cp failed: from [%s] to [%s]." % (src, dst))


if __name__ == '__main__':
    if not os.path.exists(args.local_tmp_dir):
        os.makedirs(args.local_tmp_dir)
    pcloud_mkdir(args.cloud_data_dir)

    upload_data(args.in_manifest_paths, args.out_manifest_paths,
                args.local_tmp_dir, args.cloud_data_dir, args.num_shards)

    shutil.rmtree(args.local_tmp_dir)
@@ -0,0 +1,8 @@
[
    {
        "type": "shift",
        "params": {"min_shift_ms": -5,
                   "max_shift_ms": 5},
        "prob": 1.0
    }
]
@@ -0,0 +1,39 @@
[
    {
        "type": "noise",
        "params": {"min_snr_dB": 40,
                   "max_snr_dB": 50,
                   "noise_manifest_path": "datasets/manifest.noise"},
        "prob": 0.6
    },
    {
        "type": "impulse",
        "params": {"impulse_manifest_path": "datasets/manifest.impulse"},
        "prob": 0.5
    },
    {
        "type": "speed",
        "params": {"min_speed_rate": 0.95,
                   "max_speed_rate": 1.05},
        "prob": 0.5
    },
    {
        "type": "shift",
        "params": {"min_shift_ms": -5,
                   "max_shift_ms": 5},
        "prob": 1.0
    },
    {
        "type": "volume",
        "params": {"min_gain_dBFS": -10,
                   "max_gain_dBFS": 10},
        "prob": 0.0
    },
    {
        "type": "bayesian_normal",
        "params": {"target_db": -20,
                   "prior_db": -20,
                   "prior_samples": 100},
        "prob": 0.0
    }
]
@@ -0,0 +1,110 @@
"""Prepare Aishell mandarin dataset

Download, unpack and create manifest files.
A manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import codecs
import soundfile
import json
import argparse
from data_utils.utility import download, unpack

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

URL_ROOT = 'http://www.openslr.org/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/Aishell",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()


def create_manifest(data_dir, manifest_path_prefix):
    print("Creating manifest %s ..." % manifest_path_prefix)
    json_lines = []
    transcript_path = os.path.join(data_dir, 'transcript',
                                   'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    for line in codecs.open(transcript_path, 'r', 'utf-8'):
        line = line.strip()
        if line == '': continue
        audio_id, text = line.split(' ', 1)
        # remove whitespace
        text = ''.join(text.split())
        transcript_dict[audio_id] = text

    data_types = ['train', 'dev', 'test']
    for dtype in data_types:
        del json_lines[:]
        audio_dir = os.path.join(data_dir, 'wav', dtype)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname)
                audio_id = fname[:-4]
                # skip audio files without a transcription
                if audio_id not in transcript_dict:
                    continue
                audio_data, samplerate = soundfile.read(audio_path)
                duration = float(len(audio_data)) / samplerate
                text = transcript_dict[audio_id]
                json_lines.append(
                    json.dumps(
                        {
                            'audio_filepath': audio_path,
                            'duration': duration,
                            'text': text
                        },
                        ensure_ascii=False))
        manifest_path = manifest_path_prefix + '.' + dtype
        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
            for line in json_lines:
                fout.write(line + '\n')


def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if not os.path.exists(data_dir):
        filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
        # unpack all audio tar files
        audio_dir = os.path.join(data_dir, 'wav')
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for ftar in filelist:
                unpack(os.path.join(subfolder, ftar), subfolder, True)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    create_manifest(data_dir, manifest_path)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(
        url=DATA_URL,
        md5sum=MD5_DATA,
        target_dir=args.target_dir,
        manifest_path=args.manifest_prefix)


if __name__ == '__main__':
    main()
@@ -0,0 +1,148 @@
"""Prepare Librispeech ASR datasets.

Download, unpack and create manifest files.
A manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.util
import os
import sys
import argparse
import soundfile
import json
import codecs
from data_utils.utility import download, unpack

URL_ROOT = "http://www.openslr.org/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"

MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default='~/.cache/paddle/dataset/speech/libri',
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
    "--full_download",
    default="True",
    type=distutils.util.strtobool,
    help="Download all datasets for Librispeech."
    " If False, only download a minimal requirement (test-clean, dev-clean,"
    " train-clean-100). (default: %(default)s)")
args = parser.parse_args()


def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each line
    containing the meta data (i.e. audio filepath, transcription text, audio
    duration) of each audio file within the data set.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        text_filelist = [
            filename for filename in filelist if filename.endswith('trans.txt')
        ]
        if len(text_filelist) > 0:
            text_filepath = os.path.join(data_dir, subfolder, text_filelist[0])
            for line in open(text_filepath):
                segments = line.strip().split()
                text = ' '.join(segments[1:]).lower()
                audio_filepath = os.path.join(data_dir, subfolder,
                                              segments[0] + '.flac')
                audio_data, samplerate = soundfile.read(audio_filepath)
                duration = float(len(audio_data)) / samplerate
                json_lines.append(
                    json.dumps({
                        'audio_filepath': audio_filepath,
                        'duration': duration,
                        'text': text
                    }))
    with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')


def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
        # download
        filepath = download(url, md5sum, target_dir)
        # unpack
        unpack(filepath, target_dir)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(
        url=URL_TEST_CLEAN,
        md5sum=MD5_TEST_CLEAN,
        target_dir=os.path.join(args.target_dir, "test-clean"),
        manifest_path=args.manifest_prefix + ".test-clean")
    prepare_dataset(
        url=URL_DEV_CLEAN,
        md5sum=MD5_DEV_CLEAN,
        target_dir=os.path.join(args.target_dir, "dev-clean"),
        manifest_path=args.manifest_prefix + ".dev-clean")
    if args.full_download:
        prepare_dataset(
            url=URL_TRAIN_CLEAN_100,
            md5sum=MD5_TRAIN_CLEAN_100,
            target_dir=os.path.join(args.target_dir, "train-clean-100"),
            manifest_path=args.manifest_prefix + ".train-clean-100")
        prepare_dataset(
            url=URL_TEST_OTHER,
            md5sum=MD5_TEST_OTHER,
            target_dir=os.path.join(args.target_dir, "test-other"),
            manifest_path=args.manifest_prefix + ".test-other")
        prepare_dataset(
            url=URL_DEV_OTHER,
            md5sum=MD5_DEV_OTHER,
            target_dir=os.path.join(args.target_dir, "dev-other"),
            manifest_path=args.manifest_prefix + ".dev-other")
        prepare_dataset(
            url=URL_TRAIN_CLEAN_360,
            md5sum=MD5_TRAIN_CLEAN_360,
            target_dir=os.path.join(args.target_dir, "train-clean-360"),
            manifest_path=args.manifest_prefix + ".train-clean-360")
        prepare_dataset(
            url=URL_TRAIN_OTHER_500,
            md5sum=MD5_TRAIN_OTHER_500,
            target_dir=os.path.join(args.target_dir, "train-other-500"),
            manifest_path=args.manifest_prefix + ".train-other-500")


if __name__ == '__main__':
    main()
@@ -0,0 +1,128 @@
"""Prepare CHiME3 background data.

Download, unpack and create manifest files.
A manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import distutils.util
import os
import wget
import zipfile
import tarfile
import argparse
import soundfile
import json
from paddle.v2.dataset.common import md5file

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
MD5 = "c3ff512618d7a67d4f85566ea1bc39ec"

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/chime3_background",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_filepath",
    default="manifest.chime3.background",
    type=str,
    help="Filepath for output manifests. (default: %(default)s)")
args = parser.parse_args()


def download(url, md5sum, target_dir, filename=None):
    """Download file from url to target_dir, and check md5sum."""
    if filename is None:
        filename = url.split("/")[-1]
    if not os.path.exists(target_dir): os.makedirs(target_dir)
    filepath = os.path.join(target_dir, filename)
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        wget.download(url, target_dir)
        print("\nMD5 Checksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath


def unpack(filepath, target_dir):
    """Unpack the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    if filepath.endswith('.zip'):
        zip_file = zipfile.ZipFile(filepath, 'r')
        zip_file.extractall(target_dir)
        zip_file.close()
    elif filepath.endswith('.tar') or filepath.endswith('.tar.gz'):
        tar = tarfile.open(filepath)
        tar.extractall(target_dir)
        tar.close()
    else:
        raise ValueError("File format is not supported for unpacking.")


def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each line
    containing the meta data (i.e. audio filepath, transcription text, audio
    duration) of each audio file within the data set.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        for filename in filelist:
            if filename.endswith('.wav'):
                filepath = os.path.join(data_dir, subfolder, filename)
                audio_data, samplerate = soundfile.read(filepath)
                duration = float(len(audio_data)) / samplerate
                json_lines.append(
                    json.dumps({
                        'audio_filepath': filepath,
                        'duration': duration,
                        'text': ''
                    }))
    with open(manifest_path, 'w') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')


def prepare_chime3(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if not os.path.exists(os.path.join(target_dir, "CHiME3")):
        # download
        filepath = download(url, md5sum, target_dir,
                            "myairbridge-AG0Y3DNBE5IWRRTV.zip")
        # unpack
        unpack(filepath, target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_bus.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_caf.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_ped.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_str.zip'), target_dir)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)


def main():
    prepare_chime3(
        url=URL,
        md5sum=MD5,
        target_dir=args.target_dir,
        manifest_path=args.manifest_filepath)


if __name__ == '__main__':
    main()
@@ -0,0 +1,16 @@
#! /usr/bin/env bash

# download data, generate manifests
PYTHONPATH=../../:$PYTHONPATH python voxforge.py \
    --manifest_prefix='./manifest' \
    --target_dir='~/.cache/paddle/dataset/speech/VoxForge' \
    --is_merge_dialect=True \
    --dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian'

if [ $? -ne 0 ]; then
    echo "Prepare VoxForge failed. Terminated."
    exit 1
fi

echo "VoxForge Data preparation done."
exit 0
@@ -0,0 +1,221 @@
"""Prepare VoxForge dataset

Download, unpack and create manifest files.
A manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import codecs
import datetime
import distutils.util
import soundfile
import json
import argparse
import shutil
import subprocess
from data_utils.utility import download_multi, unpack, getfile_insensitive

DATA_HOME = '~/.cache/paddle/dataset/speech'

DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \
           'Audio/Main/16kHz_16bit'

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/VoxForge",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--dialects",
    default=[
        'american', 'british', 'australian', 'european', 'irish', 'canadian',
        'indian'
    ],
    nargs='+',
    type=str,
    help="Dialect types. (default: %(default)s)")
parser.add_argument(
    "--is_merge_dialect",
    default=True,
    type=distutils.util.strtobool,
    help="If set True, manifests of american dialect and canadian dialect will "
    "be merged to american-canadian dialect; manifests of british "
    "dialect, irish dialect and australian dialect will be merged to "
    "commonwealth dialect. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()


def download_and_unpack(target_dir, url):
    wget_args = '-q -l 1 -N -nd -c -e robots=off -A tgz -r -np'
    tgz_dir = os.path.join(target_dir, 'tgz')
    exit_code = download_multi(url, tgz_dir, wget_args)
    if exit_code != 0:
        print('Download tgz audio files failed with exit code %d.' % exit_code)
    else:
        print('Download done, start unpacking ...')
        audio_dir = os.path.join(target_dir, 'audio')
        for root, dirs, files in os.walk(tgz_dir):
            for file in files:
                print(file)
                if file.endswith('.tgz'):
                    unpack(os.path.join(root, file), audio_dir)


def select_dialects(target_dir, dialect_list):
    """Classify audio files by dialect."""
    dialect_root_dir = os.path.join(target_dir, 'dialect')
    if os.path.exists(dialect_root_dir):
        shutil.rmtree(dialect_root_dir)
    os.mkdir(dialect_root_dir)
    audio_dir = os.path.abspath(os.path.join(target_dir, 'audio'))
    for dialect in dialect_list:
        # filter files by dialect
        command = 'find %s -iwholename "*etc/readme*" -exec egrep -iHl \
                  "pronunciation dialect.*%s" {} \;' % (audio_dir, dialect)
        p = subprocess.Popen(
            command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True)
        output, err = p.communicate()
        dialect_dir = os.path.join(dialect_root_dir, dialect)
        if os.path.exists(dialect_dir):
            shutil.rmtree(dialect_dir)
        os.mkdir(dialect_dir)
        for path in output.splitlines():
            src_dir = os.path.dirname(os.path.dirname(path))
            link = os.path.basename(os.path.normpath(src_dir))
            os.symlink(src_dir, os.path.join(dialect_dir, link))


def generate_manifest(data_dir, manifest_path):
    json_lines = []

    for path in os.listdir(data_dir):
        audio_link = os.path.join(data_dir, path)
        assert os.path.islink(
            audio_link), '%s should be a symbolic link.' % audio_link
        actual_audio_dir = os.path.abspath(os.readlink(audio_link))

        audio_type = ''
        if os.path.isdir(os.path.join(actual_audio_dir, 'wav')):
            audio_type = 'wav'
        elif os.path.isdir(os.path.join(actual_audio_dir, 'flac')):
            audio_type = 'flac'
        else:
            print('Unknown audio type, skipped processing %s.' %
                  actual_audio_dir)
            continue

        etc_dir = os.path.join(actual_audio_dir, 'etc')
        prompts_file = os.path.join(etc_dir, 'PROMPTS')
        if not os.path.isfile(prompts_file):
            print('PROMPTS file missing, skip processing %s.' %
                  actual_audio_dir)
            continue

        readme_file = getfile_insensitive(os.path.join(etc_dir, 'README'))
        if readme_file is None:
            print('README file missing, skip processing %s.' % actual_audio_dir)
            continue

        for line in open(prompts_file):
            u, trans = line.strip().split(None, 1)
            u_parts = u.split('/')

            # try to format the date time
            try:
                speaker, date, sfx = u_parts[-3].split('-')
                obj = datetime.datetime.strptime(date, '%y.%m.%d')
                formatted = obj.strftime('%Y%m%d')
                u_parts[-3] = '-'.join([speaker, formatted, sfx])
            except Exception as e:
                pass

            if len(u_parts) < 2:
                u_parts = [audio_type] + u_parts
            u_parts[-2] = audio_type
            u_parts[-1] += '.' + audio_type
            u = os.path.join(actual_audio_dir, '/'.join(u_parts[-2:]))

            if not os.path.isfile(u):
                print('Audio file missing, skip processing %s.' % u)
                continue

            if os.stat(u).st_size == 0:
                print('Empty audio file, skip processing %s.' % u)
                continue

            trans = trans.strip().replace('-', ' ')
            if not trans.isupper() or \
                    not trans.strip().replace(' ', '').replace("'", "").isalpha():
                print("Transcript not normalized properly, skip processing %s."
                      % u)
                continue

            audio_data, samplerate = soundfile.read(u)
            duration = float(len(audio_data)) / samplerate
            json_lines.append(
                json.dumps({
                    'audio_filepath': u,
                    'duration': duration,
                    'text': trans.lower()
                }))

    with codecs.open(manifest_path, 'w', 'utf-8') as fout:
        for line in json_lines:
            fout.write(line + '\n')


def merge_manifests(manifest_files, save_path):
    lines = []
    for manifest_file in manifest_files:
        line = codecs.open(manifest_file, 'r', 'utf-8').readlines()
        lines += line

    with codecs.open(save_path, 'w', 'utf-8') as fout:
        for line in lines:
            fout.write(line)


def prepare_dataset(url, dialects, target_dir, manifest_prefix, is_merge):
    download_and_unpack(target_dir, url)
    select_dialects(target_dir, dialects)
    american_canadian_manifests = []
    commonwealth_manifests = []
    for dialect in dialects:
        dialect_dir = os.path.join(target_dir, 'dialect', dialect)
        manifest_fpath = manifest_prefix + '.' + dialect
        if dialect == 'american' or dialect == 'canadian':
            american_canadian_manifests.append(manifest_fpath)
        if dialect == 'australian' \
                or dialect == 'british' \
                or dialect == 'irish':
            commonwealth_manifests.append(manifest_fpath)
        generate_manifest(dialect_dir, manifest_fpath)

    if is_merge:
        if len(american_canadian_manifests) > 0:
            manifest_fpath = manifest_prefix + '.american-canadian'
            merge_manifests(american_canadian_manifests, manifest_fpath)
        if len(commonwealth_manifests) > 0:
            manifest_fpath = manifest_prefix + '.commonwealth'
            merge_manifests(commonwealth_manifests, manifest_fpath)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(DATA_URL, args.dialects, args.target_dir,
                    args.manifest_prefix, args.is_merge_dialect)


if __name__ == '__main__':
    main()
@@ -0,0 +1,685 @@
"""Contains the audio segment class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import io
import struct
import re
import soundfile
import resampy
from scipy import signal
import random
import copy


class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or int.
    """
|
||||||
|
|
||||||
|
def __init__(self, samples, sample_rate):
|
||||||
|
"""Create audio segment from samples.
|
||||||
|
|
||||||
|
Samples are convert float32 internally, with int scaled to [-1, 1].
|
||||||
|
"""
|
||||||
|
self._samples = self._convert_samples_to_float32(samples)
|
||||||
|
self._sample_rate = sample_rate
|
||||||
|
if self._samples.ndim >= 2:
|
||||||
|
self._samples = np.mean(self._samples, 1)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
"""Return whether two objects are equal."""
|
||||||
|
if type(other) is not type(self):
|
||||||
|
return False
|
||||||
|
if self._sample_rate != other._sample_rate:
|
||||||
|
return False
|
||||||
|
if self._samples.shape != other._samples.shape:
|
||||||
|
return False
|
||||||
|
if np.any(self.samples != other._samples):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
"""Return whether two objects are unequal."""
|
||||||
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Return human-readable representation of segment."""
|
||||||
|
return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
|
||||||
|
"rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
|
||||||
|
self.duration, self.rms_db))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, file):
|
||||||
|
"""Create audio segment from audio file.
|
||||||
|
|
||||||
|
:param filepath: Filepath or file object to audio file.
|
||||||
|
:type filepath: basestring|file
|
||||||
|
:return: Audio segment instance.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
"""
|
||||||
|
if isinstance(file, basestring) and re.findall(r".seqbin_\d+$", file):
|
||||||
|
return cls.from_sequence_file(file)
|
||||||
|
else:
|
||||||
|
samples, sample_rate = soundfile.read(file, dtype='float32')
|
||||||
|
return cls(samples, sample_rate)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def slice_from_file(cls, file, start=None, end=None):
|
||||||
|
"""Loads a small section of an audio without having to load
|
||||||
|
the entire file into the memory which can be incredibly wasteful.
|
||||||
|
|
||||||
|
:param file: Input audio filepath or file object.
|
||||||
|
:type file: basestring|file
|
||||||
|
:param start: Start time in seconds. If start is negative, it wraps
|
||||||
|
around from the end. If not provided, this function
|
||||||
|
reads from the very beginning.
|
||||||
|
:type start: float
|
||||||
|
:param end: End time in seconds. If end is negative, it wraps around
|
||||||
|
from the end. If not provided, the default behvaior is
|
||||||
|
to read to the end of the file.
|
||||||
|
:type end: float
|
||||||
|
:return: AudioSegment instance of the specified slice of the input
|
||||||
|
audio file.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
:raise ValueError: If start or end is incorrectly set, e.g. out of
|
||||||
|
bounds in time.
|
||||||
|
"""
|
||||||
|
sndfile = soundfile.SoundFile(file)
|
||||||
|
sample_rate = sndfile.samplerate
|
||||||
|
duration = float(len(sndfile)) / sample_rate
|
||||||
|
start = 0. if start is None else start
|
||||||
|
end = 0. if end is None else end
|
||||||
|
if start < 0.0:
|
||||||
|
start += duration
|
||||||
|
if end < 0.0:
|
||||||
|
end += duration
|
||||||
|
if start < 0.0:
|
||||||
|
raise ValueError("The slice start position (%f s) is out of "
|
||||||
|
"bounds." % start)
|
||||||
|
if end < 0.0:
|
||||||
|
raise ValueError("The slice end position (%f s) is out of bounds." %
|
||||||
|
end)
|
||||||
|
if start > end:
|
||||||
|
raise ValueError("The slice start position (%f s) is later than "
|
||||||
|
"the slice end position (%f s)." % (start, end))
|
||||||
|
if end > duration:
|
||||||
|
raise ValueError("The slice end position (%f s) is out of bounds "
|
||||||
|
"(> %f s)" % (end, duration))
|
||||||
|
start_frame = int(start * sample_rate)
|
||||||
|
end_frame = int(end * sample_rate)
|
||||||
|
sndfile.seek(start_frame)
|
||||||
|
data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
|
||||||
|
return cls(data, sample_rate)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_sequence_file(cls, filepath):
|
||||||
|
"""Create audio segment from sequence file. Sequence file is a binary
|
||||||
|
file containing a collection of multiple audio files, with several
|
||||||
|
header bytes in the head indicating the offsets of each audio byte data
|
||||||
|
chunk.
|
||||||
|
|
||||||
|
The format is:
|
||||||
|
|
||||||
|
4 bytes (int, version),
|
||||||
|
4 bytes (int, num of utterance),
|
||||||
|
4 bytes (int, bytes per header),
|
||||||
|
[bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
|
||||||
|
audio_bytes_data_of_1st_utterance,
|
||||||
|
audio_bytes_data_of_2nd_utterance,
|
||||||
|
......
|
||||||
|
|
||||||
|
Sequence file name must end with ".seqbin". And the filename of the 5th
|
||||||
|
utterance's audio file in sequence file "xxx.seqbin" must be
|
||||||
|
"xxx.seqbin_5", with "5" indicating the utterance index within this
|
||||||
|
sequence file (starting from 1).
|
||||||
|
|
||||||
|
:param filepath: Filepath of sequence file.
|
||||||
|
:type filepath: basestring
|
||||||
|
:return: Audio segment instance.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
"""
|
||||||
|
# parse filepath
|
||||||
|
matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
|
||||||
|
if matches is None:
|
||||||
|
raise IOError("File type of %s is not supported" % filepath)
|
||||||
|
filename = matches.group(1)
|
||||||
|
fileno = int(matches.group(2))
|
||||||
|
|
||||||
|
# read headers
|
||||||
|
f = open(filename, 'rb')
|
||||||
|
version = f.read(4)
|
||||||
|
num_utterances = struct.unpack("i", f.read(4))[0]
|
||||||
|
bytes_per_header = struct.unpack("i", f.read(4))[0]
|
||||||
|
header_bytes = f.read(bytes_per_header * (num_utterances + 1))
|
||||||
|
header = [
|
||||||
|
struct.unpack("i", header_bytes[bytes_per_header * i:
|
||||||
|
bytes_per_header * (i + 1)])[0]
|
||||||
|
for i in range(num_utterances + 1)
|
||||||
|
]
|
||||||
|
|
||||||
|
# read audio bytes
|
||||||
|
f.seek(header[fileno - 1])
|
||||||
|
audio_bytes = f.read(header[fileno] - header[fileno - 1])
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
# create audio segment
|
||||||
|
try:
|
||||||
|
return cls.from_bytes(audio_bytes)
|
||||||
|
except Exception as e:
|
||||||
|
samples = np.frombuffer(audio_bytes, dtype='int16')
|
||||||
|
return cls(samples=samples, sample_rate=8000)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_bytes(cls, bytes):
|
||||||
|
"""Create audio segment from a byte string containing audio samples.
|
||||||
|
|
||||||
|
:param bytes: Byte string containing audio samples.
|
||||||
|
:type bytes: str
|
||||||
|
:return: Audio segment instance.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
"""
|
||||||
|
samples, sample_rate = soundfile.read(
|
||||||
|
io.BytesIO(bytes), dtype='float32')
|
||||||
|
return cls(samples, sample_rate)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def concatenate(cls, *segments):
|
||||||
|
"""Concatenate an arbitrary number of audio segments together.
|
||||||
|
|
||||||
|
:param *segments: Input audio segments to be concatenated.
|
||||||
|
:type *segments: tuple of AudioSegment
|
||||||
|
:return: Audio segment instance as concatenating results.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
:raises ValueError: If the number of segments is zero, or if the
|
||||||
|
sample_rate of any segments does not match.
|
||||||
|
:raises TypeError: If any segment is not AudioSegment instance.
|
||||||
|
"""
|
||||||
|
# Perform basic sanity-checks.
|
||||||
|
if len(segments) == 0:
|
||||||
|
raise ValueError("No audio segments are given to concatenate.")
|
||||||
|
sample_rate = segments[0]._sample_rate
|
||||||
|
for seg in segments:
|
||||||
|
if sample_rate != seg._sample_rate:
|
||||||
|
raise ValueError("Can't concatenate segments with "
|
||||||
|
"different sample rates")
|
||||||
|
if type(seg) is not cls:
|
||||||
|
raise TypeError("Only audio segments of the same type "
|
||||||
|
"can be concatenated.")
|
||||||
|
samples = np.concatenate([seg.samples for seg in segments])
|
||||||
|
return cls(samples, sample_rate)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def make_silence(cls, duration, sample_rate):
|
||||||
|
"""Creates a silent audio segment of the given duration and sample rate.
|
||||||
|
|
||||||
|
:param duration: Length of silence in seconds.
|
||||||
|
:type duration: float
|
||||||
|
:param sample_rate: Sample rate.
|
||||||
|
:type sample_rate: float
|
||||||
|
:return: Silent AudioSegment instance of the given duration.
|
||||||
|
:rtype: AudioSegment
|
||||||
|
"""
|
||||||
|
samples = np.zeros(int(duration * sample_rate))
|
||||||
|
return cls(samples, sample_rate)
|
||||||
|
|
||||||
|
def to_wav_file(self, filepath, dtype='float32'):
|
||||||
|
"""Save audio segment to disk as wav file.
|
||||||
|
|
||||||
|
:param filepath: WAV filepath or file object to save the
|
||||||
|
audio segment.
|
||||||
|
:type filepath: basestring|file
|
||||||
|
:param dtype: Subtype for audio file. Options: 'int16', 'int32',
|
||||||
|
'float32', 'float64'. Default is 'float32'.
|
||||||
|
:type dtype: str
|
||||||
|
:raises TypeError: If dtype is not supported.
|
||||||
|
"""
|
||||||
|
samples = self._convert_samples_from_float32(self._samples, dtype)
|
||||||
|
subtype_map = {
|
||||||
|
'int16': 'PCM_16',
|
||||||
|
'int32': 'PCM_32',
|
||||||
|
'float32': 'FLOAT',
|
||||||
|
'float64': 'DOUBLE'
|
||||||
|
}
|
||||||
|
soundfile.write(
|
||||||
|
filepath,
|
||||||
|
samples,
|
||||||
|
self._sample_rate,
|
||||||
|
format='WAV',
|
||||||
|
subtype=subtype_map[dtype])
|
||||||
|
|
||||||
|
def superimpose(self, other):
|
||||||
|
"""Add samples from another segment to those of this segment
|
||||||
|
(sample-wise addition, not segment concatenation).
|
||||||
|
|
||||||
|
Note that this is an in-place transformation.
|
||||||
|
|
||||||
|
:param other: Segment containing samples to be added in.
|
||||||
|
:type other: AudioSegments
|
||||||
|
:raise TypeError: If type of two segments don't match.
|
||||||
|
:raise ValueError: If the sample rates of the two segments are not
|
||||||
|
equal, or if the lengths of segments don't match.
|
||||||
|
"""
|
||||||
|
if isinstance(other, type(self)):
|
||||||
|
raise TypeError("Cannot add segments of different types: %s "
|
||||||
|
"and %s." % (type(self), type(other)))
|
||||||
|
if self._sample_rate != other._sample_rate:
|
||||||
|
raise ValueError("Sample rates must match to add segments.")
|
||||||
|
if len(self._samples) != len(other._samples):
|
||||||
|
raise ValueError("Segment lengths must match to add segments.")
|
||||||
|
self._samples += other._samples
|
||||||
|
|
||||||
|
def to_bytes(self, dtype='float32'):
|
||||||
|
"""Create a byte string containing the audio content.
|
||||||
|
|
||||||
|
:param dtype: Data type for export samples. Options: 'int16', 'int32',
|
||||||
|
'float32', 'float64'. Default is 'float32'.
|
||||||
|
:type dtype: str
|
||||||
|
:return: Byte string containing audio content.
|
||||||
|
:rtype: str
|
||||||
|
"""
|
||||||
|
samples = self._convert_samples_from_float32(self._samples, dtype)
|
||||||
|
return samples.tostring()
|
||||||
|
|
||||||
|
def gain_db(self, gain):
|
||||||
|
"""Apply gain in decibels to samples.
|
||||||
|
|
||||||
|
Note that this is an in-place transformation.
|
||||||
|
|
||||||
|
:param gain: Gain in decibels to apply to samples.
|
||||||
|
:type gain: float|1darray
|
||||||
|
"""
|
||||||
|
self._samples *= 10.**(gain / 20.)
|
||||||
|
|
||||||
|
def change_speed(self, speed_rate):
|
||||||
|
"""Change the audio speed by linear interpolation.
|
||||||
|
|
||||||
|
Note that this is an in-place transformation.
|
||||||
|
|
||||||
|
:param speed_rate: Rate of speed change:
|
||||||
|
speed_rate > 1.0, speed up the audio;
|
||||||
|
speed_rate = 1.0, unchanged;
|
||||||
|
speed_rate < 1.0, slow down the audio;
|
||||||
|
speed_rate <= 0.0, not allowed, raise ValueError.
|
||||||
|
:type speed_rate: float
|
||||||
|
:raises ValueError: If speed_rate <= 0.0.
|
||||||
|
"""
|
||||||
|
if speed_rate <= 0:
|
||||||
|
raise ValueError("speed_rate should be greater than zero.")
|
||||||
|
old_length = self._samples.shape[0]
|
||||||
|
new_length = int(old_length / speed_rate)
|
||||||
|
old_indices = np.arange(old_length)
|
||||||
|
new_indices = np.linspace(start=0, stop=old_length, num=new_length)
|
||||||
|
self._samples = np.interp(new_indices, old_indices, self._samples)

    def normalize(self, target_db=-20, max_gain_db=300.0):
        """Normalize audio to be of the desired RMS value in decibels.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels. This value should be
                          less than 0.0 as 0.0 is full-scale audio.
        :type target_db: float
        :param max_gain_db: Max amount of gain in dB that can be applied for
                            normalization. This is to prevent nans when
                            attempting to normalize a signal consisting of
                            all zeros.
        :type max_gain_db: float
        :raises ValueError: If the required gain to normalize the segment to
                            the target_db value exceeds max_gain_db.
        """
        gain = target_db - self.rms_db
        if gain > max_gain_db:
            raise ValueError(
                "Unable to normalize segment to %f dB because the "
                "required gain exceeds max_gain_db (%f dB)." %
                (target_db, max_gain_db))
        self.gain_db(min(max_gain_db, target_db - self.rms_db))

    def normalize_online_bayesian(self,
                                  target_db,
                                  prior_db,
                                  prior_samples,
                                  startup_delay=0.0):
        """Normalize audio using a production-compatible online/causal
        algorithm. This uses an exponential likelihood and gamma prior to
        make online estimates of the RMS even when there are very few samples.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels.
        :type target_db: float
        :param prior_db: Prior RMS estimate in decibels.
        :type prior_db: float
        :param prior_samples: Prior strength in number of samples.
        :type prior_samples: float
        :param startup_delay: Default 0.0s. If provided, this function will
                              accrue statistics for the first startup_delay
                              seconds before applying online normalization.
        :type startup_delay: float
        """
        # Estimate total RMS online.
        startup_sample_idx = min(self.num_samples - 1,
                                 int(self.sample_rate * startup_delay))
        prior_mean_squared = 10.**(prior_db / 10.)
        prior_sum_of_squares = prior_mean_squared * prior_samples
        cumsum_of_squares = np.cumsum(self.samples**2)
        sample_count = np.arange(self.num_samples) + 1
        if startup_sample_idx > 0:
            cumsum_of_squares[:startup_sample_idx] = \
                cumsum_of_squares[startup_sample_idx]
            sample_count[:startup_sample_idx] = \
                sample_count[startup_sample_idx]
        mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
                                 (sample_count + prior_samples))
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
        self.gain_db(gain_db)
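
    # Illustrative note on the estimator above (a sketch, not original code):
    # the prior contributes `prior_samples` pseudo-observations with mean
    # square 10 ** (prior_db / 10.), so early estimates are pulled toward
    # prior_db and converge to the empirical RMS as real samples accumulate.
    # E.g. with prior_db = -20 (mean square 0.01), prior_samples = 100 and a
    # first observed sample of 0.2 (square 0.04):
    #     mean_squared_estimate[0] = (0.04 + 0.01 * 100) / (1 + 100) ~= 0.0103
    # which stays close to the prior, as intended when few samples are seen.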

    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use, one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(
            self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate

    def pad_silence(self, duration, sides='both'):
        """Pad this audio sample with a period of silence.

        Note that this is an in-place transformation.

        :param duration: Length of silence in seconds to pad.
        :type duration: float
        :param sides: Position for padding:
                      'beginning' - adds silence in the beginning;
                      'end' - adds silence in the end;
                      'both' - adds silence in both the beginning and the end.
        :type sides: str
        :raises ValueError: If sides is not supported.
        """
        if duration == 0.0:
            return self
        cls = type(self)
        silence = self.make_silence(duration, self._sample_rate)
        if sides == "beginning":
            padded = cls.concatenate(silence, self)
        elif sides == "end":
            padded = cls.concatenate(self, silence)
        elif sides == "both":
            padded = cls.concatenate(silence, self, silence)
        else:
            raise ValueError("Unknown value for sides: %s" % sides)
        self._samples = padded._samples

    def shift(self, shift_ms):
        """Shift the audio in time. If `shift_ms` is positive, shift with time
        advance; if negative, shift with time delay. Silence is padded to
        keep the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in milliseconds. If positive, shift with
                         time advance; if negative, shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("Absolute value of shift_ms should be smaller "
                             "than audio duration.")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
        if shift_samples > 0:
            # time advance
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0

    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.

        Note that this is an in-place transformation.

        :param start_sec: Beginning of subsegment in seconds.
        :type start_sec: float
        :param end_sec: End of subsegment in seconds.
        :type end_sec: float
        :raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
                           of bounds in time.
        """
        start_sec = 0.0 if start_sec is None else start_sec
        end_sec = self.duration if end_sec is None else end_sec
        if start_sec < 0.0:
            start_sec = self.duration + start_sec
        if end_sec < 0.0:
            end_sec = self.duration + end_sec
        if start_sec < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start_sec)
        if end_sec < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end_sec)
        if start_sec > end_sec:
            raise ValueError("The slice start position (%f s) is later than "
                             "the end position (%f s)." % (start_sec, end_sec))
        if end_sec > self.duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end_sec, self.duration))
        start_sample = int(round(start_sec * self._sample_rate))
        end_sample = int(round(end_sec * self._sample_rate))
        self._samples = self._samples[start_sample:end_sample]

    def random_subsegment(self, subsegment_length, rng=None):
        """Cut the specified length of the audio segment randomly.

        Note that this is an in-place transformation.

        :param subsegment_length: Subsegment length in seconds.
        :type subsegment_length: float
        :param rng: Random number generator state.
        :type rng: random.Random
        :raises ValueError: If the length of subsegment is greater than
                            the original segment.
        """
        rng = random.Random() if rng is None else rng
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = rng.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)

    def convolve(self, impulse_segment, allow_resample=False):
        """Convolve this audio segment with the given impulse segment.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        :raises ValueError: If the sample rates of the two audio segments do
                            not match and resampling is not allowed.
        """
        if allow_resample and self.sample_rate != impulse_segment.sample_rate:
            impulse_segment.resample(self.sample_rate)
        if self.sample_rate != impulse_segment.sample_rate:
            raise ValueError("Impulse segment's sample rate (%d Hz) is not "
                             "equal to base signal sample rate (%d Hz)." %
                             (impulse_segment.sample_rate, self.sample_rate))
        samples = signal.fftconvolve(self.samples, impulse_segment.samples,
                                     "full")
        self._samples = samples

    def convolve_and_normalize(self, impulse_segment, allow_resample=False):
        """Convolve and normalize the resulting audio segment so that it
        has the same average power as the input signal.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        """
        target_db = self.rms_db
        self.convolve(impulse_segment, allow_resample=allow_resample)
        self.normalize(target_db)

    def add_noise(self,
                  noise,
                  snr_dB,
                  allow_downsampling=False,
                  max_gain_db=300.0,
                  rng=None):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random subsegment
        of matching length is sampled from it and used instead.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param allow_downsampling: Whether to allow the noise signal to be
                                   downsampled to match the base signal sample
                                   rate.
        :type allow_downsampling: bool
        :param max_gain_db: Maximum amount of gain to apply to noise signal
                            before adding it in. This is to prevent attempting
                            to apply infinite gain to a zero signal.
        :type max_gain_db: float
        :param rng: Random number generator state.
        :type rng: None|random.Random
        :raises ValueError: If the sample rate does not match between the two
                            audio segments when downsampling is not allowed, or
                            if the duration of the noise segment is shorter
                            than the original audio segment.
        """
        rng = random.Random() if rng is None else rng
        if allow_downsampling and noise.sample_rate > self.sample_rate:
            # resample() works in place, so no reassignment is needed
            noise.resample(self.sample_rate)
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) is not equal to base "
                             "signal sample rate (%d Hz)." % (noise.sample_rate,
                                                              self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("Noise signal (%f sec) must be at least as long as"
                             " base signal (%f sec)." %
                             (noise.duration, self.duration))
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)
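
    # Illustrative note on the gain rule above (a sketch, not original code):
    # to reach a target SNR, the noise must sit snr_dB below the signal, so
    # the required noise gain is rms_db(signal) - rms_db(noise) - snr_dB.
    # E.g. a -20 dB signal, a -25 dB noise source and snr_dB = 10 give
    # -20 - (-25) - 10 = -5 dB of gain applied to the noise copy before it is
    # superimposed.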

    @property
    def samples(self):
        """Return audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return root mean square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # square root => multiply by 10 instead of 20 for dBs
        mean_square = np.mean(self._samples**2)
        return 10 * np.log10(mean_square)

    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or floating-point.
        Integers will be scaled to [-1, 1] in float32.
        """
        float32_samples = samples.astype('float32')
        if samples.dtype in np.sctypes['int']:
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / 2**(bits - 1))
        elif samples.dtype in np.sctypes['float']:
            pass
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return float32_samples

    def _convert_samples_from_float32(self, samples, dtype):
        """Convert sample type from float32 to dtype.

        Audio sample type is usually integer or floating-point. For integer
        type, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.

        This is for writing an audio file.
        """
        dtype = np.dtype(dtype)
        output_samples = samples.copy()
        if dtype in np.sctypes['int']:
            bits = np.iinfo(dtype).bits
            output_samples *= (2**(bits - 1) / 1.)
            min_val = np.iinfo(dtype).min
            max_val = np.iinfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        elif dtype in np.sctypes['float']:
            min_val = np.finfo(dtype).min
            max_val = np.finfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return output_samples.astype(dtype)
@ -0,0 +1,124 @@
"""Contains the data augmentation pipeline."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import \
    OnlineBayesianNormalizationAugmentor


class AugmentationPipeline(object):
    """Build a pre-processing pipeline with various augmentation models. Such a
    data augmentation pipeline is often leveraged to augment the training
    samples to make the model invariant to certain types of perturbations in
    the real world, improving the model's generalization ability.

    The pipeline is built according to the augmentation configuration in json
    string, e.g.

    .. code-block::

        [ {
                "type": "noise",
                "params": {"min_snr_dB": 10,
                           "max_snr_dB": 20,
                           "noise_manifest_path": "datasets/manifest.noise"},
                "prob": 0.0
            },
            {
                "type": "speed",
                "params": {"min_speed_rate": 0.9,
                           "max_speed_rate": 1.1},
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {"min_shift_ms": -5,
                           "max_shift_ms": 5},
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {"min_gain_dBFS": -10,
                           "max_gain_dBFS": 10},
                "prob": 0.0
            },
            {
                "type": "bayesian_normal",
                "params": {"target_db": -20,
                           "prior_db": -20,
                           "prior_samples": 100},
                "prob": 0.0
            }
        ]

    This augmentation configuration inserts five augmentation models into the
    pipeline, of which only those with a non-zero "prob" (here the speed and
    shift perturbations) can take effect. "prob" indicates the probability of
    the current augmentor taking effect; if "prob" is zero, the augmentor does
    not take effect.

    :param augmentation_config: Augmentation configuration in json string.
    :type augmentation_config: str
    :param random_seed: Random seed.
    :type random_seed: int
    :raises ValueError: If the augmentation json config is in incorrect format.
    """

    def __init__(self, augmentation_config, random_seed=0):
        self._rng = random.Random(random_seed)
        self._augmentors, self._rates = self._parse_pipeline_from(
            augmentation_config)

    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._augmentors, self._rates):
            if self._rng.uniform(0., 1.) < rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        """Parse the config json to build an augmentation pipeline."""
        try:
            configs = json.loads(config_json)
            augmentors = [
                self._get_augmentor(config["type"], config["params"])
                for config in configs
            ]
            rates = [config["prob"] for config in configs]
        except Exception as e:
            raise ValueError("Failed to parse the augmentation config json: "
                             "%s" % str(e))
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "shift":
            return ShiftPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "speed":
            return SpeedPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "resample":
            return ResampleAugmentor(self._rng, **params)
        elif augmentor_type == "bayesian_normal":
            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
        elif augmentor_type == "noise":
            return NoisePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "impulse":
            return ImpulseResponseAugmentor(self._rng, **params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
@ -0,0 +1,33 @@
"""Contains the abstract base class for augmentation models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from abc import ABCMeta, abstractmethod


class AugmentorBase(object):
    """Abstract base class for augmentation model (augmentor) classes.
    All augmentor classes should inherit from this class, and implement the
    following abstract methods.
    """

    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Adds various effects to the input audio segment. Such effects
        will augment the training data to make the model invariant to certain
        types of perturbations in the real world, improving the model's
        generalization ability.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        pass
@ -0,0 +1,34 @@
"""Contains the impulse response augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment


class ImpulseResponseAugmentor(AugmentorBase):
    """Augmentation model for adding impulse response effect.

    :param rng: Random generator object.
    :type rng: random.Random
    :param impulse_manifest_path: Manifest path for impulse audio data.
    :type impulse_manifest_path: basestring
    """

    def __init__(self, rng, impulse_manifest_path):
        self._rng = rng
        self._impulse_manifest = read_manifest(impulse_manifest_path)

    def transform_audio(self, audio_segment):
        """Add impulse response effect.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        impulse_json = self._rng.sample(self._impulse_manifest, 1)[0]
        impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath'])
        audio_segment.convolve(impulse_segment, allow_resample=True)
@ -0,0 +1,49 @@
"""Contains the noise perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment


class NoisePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding background noise.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_snr_dB: Minimal signal-to-noise ratio, in decibels.
    :type min_snr_dB: float
    :param max_snr_dB: Maximal signal-to-noise ratio, in decibels.
    :type max_snr_dB: float
    :param noise_manifest_path: Manifest path for noise audio data.
    :type noise_manifest_path: basestring
    """

    def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
        self._min_snr_dB = min_snr_dB
        self._max_snr_dB = max_snr_dB
        self._rng = rng
        self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)

    def transform_audio(self, audio_segment):
        """Add background noise audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        noise_json = self._rng.sample(self._noise_manifest, 1)[0]
        if noise_json['duration'] < audio_segment.duration:
            raise RuntimeError("The duration of sampled noise audio is smaller "
                               "than the audio segment to add effects to.")
        diff_duration = noise_json['duration'] - audio_segment.duration
        start = self._rng.uniform(0, diff_duration)
        end = start + audio_segment.duration
        noise_segment = AudioSegment.slice_from_file(
            noise_json['audio_filepath'], start=start, end=end)
        snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
        audio_segment.add_noise(
            noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
@ -0,0 +1,48 @@
"""Contains the online Bayesian normalization augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class OnlineBayesianNormalizationAugmentor(AugmentorBase):
    """Augmentation model for adding online Bayesian normalization.

    :param rng: Random generator object.
    :type rng: random.Random
    :param target_db: Target RMS value in decibels.
    :type target_db: float
    :param prior_db: Prior RMS estimate in decibels.
    :type prior_db: float
    :param prior_samples: Prior strength in number of samples.
    :type prior_samples: int
    :param startup_delay: Default 0.0s. If provided, this function will
                          accrue statistics for the first startup_delay
                          seconds before applying online normalization.
    :type startup_delay: float
    """

    def __init__(self,
                 rng,
                 target_db,
                 prior_db,
                 prior_samples,
                 startup_delay=0.0):
        self._target_db = target_db
        self._prior_db = prior_db
        self._prior_samples = prior_samples
        self._rng = rng
        self._startup_delay = startup_delay

    def transform_audio(self, audio_segment):
        """Normalizes the input audio using the online Bayesian approach.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.normalize_online_bayesian(self._target_db, self._prior_db,
                                                self._prior_samples,
                                                self._startup_delay)
@ -0,0 +1,33 @@
"""Contains the resample augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class ResampleAugmentor(AugmentorBase):
    """Augmentation model for resampling.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param rng: Random generator object.
    :type rng: random.Random
    :param new_sample_rate: New sample rate in Hz.
    :type new_sample_rate: int
    """

    def __init__(self, rng, new_sample_rate):
        self._new_sample_rate = new_sample_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Resamples the input audio to a target sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.resample(self._new_sample_rate)
@ -0,0 +1,34 @@
"""Contains the shift perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random shift perturbation.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_shift_ms: Minimal shift in milliseconds.
    :type min_shift_ms: float
    :param max_shift_ms: Maximal shift in milliseconds.
    :type max_shift_ms: float
    """

    def __init__(self, rng, min_shift_ms, max_shift_ms):
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Shift audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(shift_ms)
@ -0,0 +1,47 @@
"""Contains the speed perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding speed perturbation.

    See reference paper here:
    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_speed_rate: Lower bound of new speed rate to sample, which
                           should not be smaller than 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of new speed rate to sample, which
                           should not be larger than 1.1.
    :type max_speed_rate: float
    """

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Sample a new speed rate from the given range and change the speed
        of the given audio clip.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        sampled_speed = self._rng.uniform(self._min_speed_rate,
                                          self._max_speed_rate)
        audio_segment.change_speed(sampled_speed)
@ -0,0 +1,40 @@
"""Contains the volume perturb augmentation model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.augmentor.base import AugmentorBase


class VolumePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random volume perturbation.

    This is used for multi-loudness training of PCEN. See

    https://arxiv.org/pdf/1607.05666v1.pdf

    for more details.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_gain_dBFS: Minimal gain in dBFS.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Maximal gain in dBFS.
    :type max_gain_dBFS: float
    """

    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Change audio loudness.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)
@ -0,0 +1,346 @@
"""Contains data generator for organizing various audio data preprocessing
pipeline and offering data reader interface of PaddlePaddle requirements.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import tarfile
import multiprocessing
import numpy as np
import paddle.v2 as paddle
from threading import local
from data_utils.utility import read_manifest
from data_utils.utility import xmap_readers_mp
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer


class DataGenerator(object):
    """
    DataGenerator provides basic audio data preprocessing pipeline, and offers
    data reader interfaces of PaddlePaddle requirements.

    :param vocab_filepath: Vocabulary filepath for indexing tokenized
                           transcripts.
    :type vocab_filepath: basestring
    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|basestring
    :param augmentation_config: Augmentation configuration in json string.
                                Details see AugmentationPipeline.__doc__.
    :type augmentation_config: str
    :param max_duration: Audio with duration (in seconds) greater than
                         this will be discarded.
    :type max_duration: float
    :param min_duration: Audio with duration (in seconds) smaller than
                         this will be discarded.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Used when specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                 before extracting the features.
    :type use_dB_normalization: bool
    :param num_threads: Number of CPU threads for processing data.
    :type num_threads: int
    :param random_seed: Random seed.
    :type random_seed: int
    :param keep_transcription_text: If set to True, transcription text will
                                    be passed forward directly without
                                    converting to index sequence.
    :type keep_transcription_text: bool
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 num_threads=multiprocessing.cpu_count() // 2,
                 random_seed=0,
                 keep_transcription_text=False):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._num_threads = num_threads
        self._rng = random.Random(random_seed)
        self._keep_transcription_text = keep_transcription_text
        self._epoch = 0
        # for caching tar files info
        self._local_data = local()
        self._local_data.tar2info = {}
        self._local_data.tar2object = {}

    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: basestring | file
        :param transcript: Transcription text.
        :type transcript: basestring
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, basestring) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        specgram = self._normalizer.apply(specgram)
        return specgram, transcript_part

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             min_batch_size=1,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             shuffle_method="batch_shuffle"):
        """
        Batch data reader creator for audio data. Return a callable generator
        function to produce batches of data.

        Audio features within one batch will be padded with zeros to have the
        same shape, or a user-defined shape.

        :param manifest_path: Filepath of manifest for audio files.
        :type manifest_path: basestring
        :param batch_size: Number of instances in a batch.
        :type batch_size: int
        :param min_batch_size: Any batch with batch size smaller than this will
                               be discarded. (To be deprecated in the future.)
        :type min_batch_size: int
        :param padding_to: If set -1, the maximum shape in the batch
                           will be used as the target shape for padding.
                           Otherwise, `padding_to` will be the target shape.
        :type padding_to: int
        :param flatten: If set True, audio features will be flattened to
                        1darray.
        :type flatten: bool
        :param sortagrad: If set True, sort the instances by audio duration
                          in the first epoch to speed up training.
        :type sortagrad: bool
        :param shuffle_method: Shuffle method. Options:
                               '' or None: no shuffle.
                               'instance_shuffle': instance-wise shuffle.
                               'batch_shuffle': similarly-sized instances are
                                                put into batches, and then
                                                batch-wise shuffle the batches.
                                                For more details, please see
                                                ``_batch_shuffle.__doc__``.
                               'batch_shuffle_clipped': 'batch_shuffle' with
                                                        head shift and tail
                                                        clipping. For more
                                                        details, please see
                                                        ``_batch_shuffle``.
                               If sortagrad is True, shuffle is disabled
                               for the first epoch.
        :type shuffle_method: None|str
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            else:
                if shuffle_method == "batch_shuffle":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=False)
                elif shuffle_method == "batch_shuffle_clipped":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=True)
                elif shuffle_method == "instance_shuffle":
                    self._rng.shuffle(manifest)
                elif shuffle_method is None:
                    pass
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     shuffle_method)
            # prepare batches
            instance_reader, cleanup = self._instance_reader_creator(manifest)
            batch = []
            try:
                for instance in instance_reader():
                    batch.append(instance)
                    if len(batch) == batch_size:
                        yield self._padding_batch(batch, padding_to, flatten)
                        batch = []
                if len(batch) >= min_batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
            finally:
                cleanup()
            self._epoch += 1

        return batch_reader
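
    # Example usage (an illustrative sketch under the API above; the file
    # paths are hypothetical placeholders):
    #
    #     generator = DataGenerator(
    #         vocab_filepath='datasets/vocab/eng_vocab.txt',
    #         mean_std_filepath='mean_std.npz')
    #     train_reader = generator.batch_reader_creator(
    #         manifest_path='datasets/manifest.train',
    #         batch_size=16,
    #         sortagrad=True,
    #         shuffle_method='batch_shuffle_clipped')
    #     for batch in train_reader():
    #         pass  # each item is [padded_audio, transcript, original_length]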

    @property
    def feeding(self):
        """Returns data reader's feeding dict.

        :return: Data feeding dict.
        :rtype: dict
        """
        feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
        return feeding_dict

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map containing tarinfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from tar file
        and cached tar file info for next reading request.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        Instance: a tuple of ndarray of audio spectrogram and a list of
        token indices for transcript.
        """

        def reader():
            for instance in manifest:
                yield instance

        reader, cleanup_callback = xmap_readers_mp(
            lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
            reader, self._num_threads, 4096)

        return reader, cleanup_callback

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Pad audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``padding_to`` is -1, the maximum shape in the batch will be used
        as the target shape for padding. Otherwise, `padding_to` will be the
        target shape (only refers to the second axis).

        If `flatten` is True, features will be flattened to 1darray.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            padded_instance = [padded_audio, text, audio.shape[1]]
            new_batch.append(padded_instance)
        return new_batch

    def _batch_shuffle(self, manifest, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better efficiency
        and make a batch-wise shuffle.

        1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        :param manifest: Manifest contents. List of dict.
        :type manifest: list
        :param batch_size: Batch size. This size is also used to generate
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
        self._rng.shuffle(batch_manifest)
        batch_manifest = [item for batch in batch_manifest for item in batch]
        if not clipped:
            res_len = len(manifest) - shift_len - len(batch_manifest)
            if res_len > 0:
                # guard against res_len == 0, where manifest[-0:] would
                # append the whole manifest again
                batch_manifest.extend(manifest[-res_len:])
            batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest
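
    # Illustrative walk-through of the shuffle above (a sketch, not original
    # code): with 10 sorted instances, batch_size = 3 and a sampled shift of
    # k = 2, instances 2..7 form two minibatches ([2, 3, 4], [5, 6, 7]) that
    # are shuffled batch-wise; with clipped=False the leftover tail [8, 9] and
    # the shifted head [0, 1] are appended back, so no instance is dropped.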
@ -0,0 +1,187 @@
"""Contains the audio featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment
from python_speech_features import mfcc
from python_speech_features import delta


class AudioFeaturizer(object):
    """Audio featurizer, for extracting features from audio contents of
    AudioSegment or SpeechSegment.

    Currently, it supports feature types of linear spectrogram and mfcc.

    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Audio are resampled (if upsampling or
                               downsampling is allowed) to this before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibels before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
        self._target_sample_rate = target_sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB

    def featurize(self,
                  audio_segment,
                  allow_downsampling=True,
                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
        :param allow_downsampling: Whether to allow audio downsampling before
                                   featurizing.
        :type allow_downsampling: bool
        :param allow_upsampling: Whether to allow audio upsampling before
                                 featurizing.
        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        :raises ValueError: If audio sample rate is not supported.
        """
        # upsampling or downsampling
        if ((audio_segment.sample_rate > self._target_sample_rate and
             allow_downsampling) or
            (audio_segment.sample_rate < self._target_sample_rate and
             allow_upsampling)):
            audio_segment.resample(self._target_sample_rate)
        if audio_segment.sample_rate != self._target_sample_rate:
            raise ValueError("Audio sample rate is not supported. "
                             "Turn allow_downsampling or allow_upsampling on.")
        # decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # extract spectrogram
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)

    def _compute_specgram(self, samples, sample_rate):
        """Extract various audio features."""
        if self._specgram_type == 'linear':
            return self._compute_linear_specgram(
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        elif self._specgram_type == 'mfcc':
            return self._compute_mfcc(samples, sample_rate, self._stride_ms,
                                      self._window_ms, self._max_freq)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc." %
                             self._specgram_type)

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Compute the linear spectrogram from FFT energy."""
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(specgram[:ind, :] + eps)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram for samples from a real signal."""
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)
        fft = fft**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs
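
    # Illustrative note on the strided view above (a sketch, not original
    # code): with 16 kHz audio, window_ms = 20 and stride_ms = 10, each column
    # is a 320-sample window starting 160 samples after the previous one, so
    # one second of audio yields (16000 - 320) // 160 + 1 = 99 windows without
    # copying the underlying sample buffer.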

    def _compute_mfcc(self,
                      samples,
                      sample_rate,
                      stride_ms=10.0,
                      window_ms=20.0,
                      max_freq=None):
        """Compute mfcc from samples."""
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # compute the 13 cepstral coefficients, and the first one is replaced
        # by log(frame energy)
        mfcc_feat = mfcc(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            highfreq=max_freq)
        # Deltas
        d_mfcc_feat = delta(mfcc_feat, 2)
        # Deltas-Deltas
        dd_mfcc_feat = delta(d_mfcc_feat, 2)
        # transpose
        mfcc_feat = np.transpose(mfcc_feat)
        d_mfcc_feat = np.transpose(d_mfcc_feat)
        dd_mfcc_feat = np.transpose(dd_mfcc_feat)
        # concat above three features
        concat_mfcc_feat = np.concatenate(
            (mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
        return concat_mfcc_feat
|
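
The strided-window trick in _specgram_real above builds all analysis frames as views into the sample buffer instead of copies. Below is a minimal, self-contained sketch of the same computation on a synthetic tone; the function and variable names are illustrative, not part of the module:

    import numpy as np

    def linear_specgram(samples, sample_rate, stride_ms=10.0, window_ms=20.0,
                        eps=1e-14):
        """Log linear spectrogram via strided windows, as in _specgram_real."""
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        # drop the tail that does not fill a whole stride
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        # Hanning-weighted, squared FFT magnitude, scaled to a power spectrum
        weighting = np.hanning(window_size)[:, None]
        fft = np.absolute(np.fft.rfft(windows * weighting, axis=0))**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        return np.log(fft + eps)

    # 440 Hz tone, 1 second at 16 kHz
    t = np.arange(16000) / 16000.0
    specgram = linear_specgram(np.sin(2 * np.pi * 440 * t), 16000)
    print(specgram.shape)  # (161, 99): frequency bins x frames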
@ -0,0 +1,98 @@
"""Contains the speech featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer


class SpeechFeaturizer(object):
    """Speech featurizer, for extracting features from both the audio and the
    transcript contents of SpeechSegment.

    Currently, for audio parts, it supports feature types of linear
    spectrogram and mfcc; for transcript parts, it only supports char-level
    tokenizing and conversion into a list of token indices. Note that the
    token indexing order follows the given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: basestring
    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of the mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Speech is resampled (if upsampling or
                               downsampling is allowed) to this sample rate
                               before extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 vocab_filepath,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment, keep_transcription_text):
        """Extract features for speech segment.

        1. For audio parts, extract the audio features.
        2. For transcript parts, keep the original text or convert the text
           string to a list of char-level token indices.

        :param speech_segment: Speech segment to extract features from.
        :type speech_segment: SpeechSegment
        :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
                 char-level token indices.
        :rtype: tuple
        """
        audio_feature = self._audio_featurizer.featurize(speech_segment)
        if keep_transcription_text:
            return audio_feature, speech_segment.transcript
        text_ids = self._text_featurizer.featurize(speech_segment.transcript)
        return audio_feature, text_ids

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._text_featurizer.vocab_list
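
A sketch of how the featurizer is meant to be driven end to end. The featurizer import path follows the imports above; 'vocab.txt' and 'sample.wav' are placeholder files, and the data_utils.speech location of SpeechSegment is an assumption:

    from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
    from data_utils.speech import SpeechSegment  # assumed module path

    featurizer = SpeechFeaturizer(vocab_filepath='vocab.txt',
                                  specgram_type='linear')
    segment = SpeechSegment.from_file('sample.wav', transcript='hello world')
    # spectrogram as a 2-D array, transcript as char-level token indices
    specgram, token_ids = featurizer.featurize(
        segment, keep_transcription_text=False)
    print(specgram.shape)
    print(token_ids)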
@ -0,0 +1,68 @@
"""Contains the text featurizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import codecs


class TextFeaturizer(object):
    """Text featurizer, for processing or extracting features from text.

    Currently, it only supports char-level tokenizing and conversion into
    a list of token indices. Note that the token indexing order follows the
    given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: basestring
    """

    def __init__(self, vocab_filepath):
        self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
            vocab_filepath)

    def featurize(self, text):
        """Convert text string to a list of token indices in char-level. Note
        that the token indexing order follows the given vocabulary file.

        :param text: Text to process.
        :type text: basestring
        :return: List of char-level token indices.
        :rtype: list
        """
        tokens = self._char_tokenize(text)
        return [self._vocab_dict[token] for token in tokens]

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self._vocab_list)

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._vocab_list

    def _char_tokenize(self, text):
        """Character tokenizer."""
        return list(text.strip())

    def _load_vocabulary_from_file(self, vocab_filepath):
        """Load vocabulary from file."""
        vocab_lines = []
        with codecs.open(vocab_filepath, 'r', 'utf-8') as file:
            vocab_lines.extend(file.readlines())
        vocab_list = [line[:-1] for line in vocab_lines]
        vocab_dict = dict(
            [(token, token_id) for (token_id, token) in enumerate(vocab_list)])
        return vocab_dict, vocab_list
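
The featurizer reduces to a dict keyed by line order in the vocabulary file. The same indexing, inlined on a made-up vocabulary:

    vocab_list = [' ', 'a', 'b', 'c']  # line order fixes the indices
    vocab_dict = dict((token, idx) for idx, token in enumerate(vocab_list))

    text = "abc a"
    tokens = list(text.strip())                     # char-level tokenization
    print([vocab_dict[token] for token in tokens])  # [1, 2, 3, 0, 1]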
@ -0,0 +1,87 @@
"""Contains feature normalizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import random
from data_utils.utility import read_manifest
from data_utils.audio import AudioSegment


class FeatureNormalizer(object):
    """Feature normalizer. Normalize features to be of zero mean and unit
    stddev.

    If mean_std_filepath is provided (not None), the normalizer will directly
    initialize from the file. Otherwise, both manifest_path and featurize_func
    should be given for on-the-fly mean and stddev computing.

    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|basestring
    :param manifest_path: Manifest of instances for computing mean and stddev.
    :type manifest_path: None|basestring
    :param featurize_func: Function to extract features. It should be callable
                           with ``featurize_func(audio_segment)``.
    :type featurize_func: None|callable
    :param num_samples: Number of random samples for computing mean and stddev.
    :type num_samples: int
    :param random_seed: Random seed for sampling instances.
    :type random_seed: int
    :raises ValueError: If both mean_std_filepath and manifest_path
                        (or both mean_std_filepath and featurize_func) are None.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = random.Random(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features, eps=1e-14):
        """Normalize features to be of zero mean and unit stddev.

        :param features: Input features to be normalized.
        :type features: ndarray
        :param eps: Added to stddev to provide numerical stability.
        :type eps: float
        :return: Normalized features.
        :rtype: ndarray
        """
        return (features - self._mean) / (self._std + eps)

    def write_to_file(self, filepath):
        """Write the mean and stddev to the file.

        :param filepath: File to write mean and stddev.
        :type filepath: basestring
        """
        np.savez(filepath, mean=self._mean, std=self._std)

    def _read_mean_std_from_file(self, filepath):
        """Load mean and std from file."""
        npzfile = np.load(filepath)
        self._mean = npzfile["mean"]
        self._std = npzfile["std"]

    def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
        """Compute mean and std from randomly sampled instances."""
        manifest = read_manifest(manifest_path)
        sampled_manifest = self._rng.sample(manifest, num_samples)
        features = []
        for instance in sampled_manifest:
            features.append(
                featurize_func(
                    AudioSegment.from_file(instance["audio_filepath"])))
        features = np.hstack(features)
        self._mean = np.mean(features, axis=1).reshape([-1, 1])
        self._std = np.std(features, axis=1).reshape([-1, 1])
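
The per-dimension statistics in _compute_mean_std and the apply formula boil down to a few numpy calls. A toy check on a made-up feature matrix:

    import numpy as np

    # toy feature matrix: rows are feature dims, columns are frames
    features = np.array([[1.0, 2.0, 3.0],
                         [10.0, 20.0, 30.0]])
    mean = np.mean(features, axis=1).reshape([-1, 1])
    std = np.std(features, axis=1).reshape([-1, 1])
    eps = 1e-14
    normalized = (features - mean) / (std + eps)
    print(normalized.mean(axis=1))  # ~[0, 0]
    print(normalized.std(axis=1))   # ~[1, 1]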
@ -0,0 +1,143 @@
"""Contains the speech segment class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from data_utils.audio import AudioSegment


class SpeechSegment(AudioSegment):
    """Speech segment abstraction, a subclass of AudioSegment,
    with an additional transcript.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :param transcript: Transcript text for the speech.
    :type transcript: basestring
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    def __eq__(self, other):
        """Return whether two objects are equal."""
        if not AudioSegment.__eq__(self, other):
            return False
        if self._transcript != other._transcript:
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    @classmethod
    def from_file(cls, filepath, transcript):
        """Create speech segment from audio file and corresponding transcript.

        :param filepath: Filepath or file object to audio file.
        :type filepath: basestring|file
        :param transcript: Transcript text for the speech.
        :type transcript: basestring
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def from_bytes(cls, bytes, transcript):
        """Create speech segment from a byte string and corresponding
        transcript.

        :param bytes: Byte string containing audio samples.
        :type bytes: str
        :param transcript: Transcript text for the speech.
        :type transcript: basestring
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of speech segments together; both
        audio and transcript will be concatenated.

        :param *segments: Input speech segments to be concatenated.
        :type *segments: tuple of SpeechSegment
        :return: Speech segment instance.
        :rtype: SpeechSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any two segments does not match.
        :raises TypeError: If any segment is not a SpeechSegment instance.
        """
        if len(segments) == 0:
            raise ValueError("No speech segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        transcripts = ""
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only speech segments of the same type "
                                "instance can be concatenated.")
            transcripts += seg._transcript
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate, transcripts)

    @classmethod
    def slice_from_file(cls, filepath, transcript, start=None, end=None):
        """Load a small section of a speech file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param filepath: Filepath or file object to audio file.
        :type filepath: basestring|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :param transcript: Transcript text for the speech. If not provided,
                           the default is an empty string.
        :type transcript: basestring
        :return: SpeechSegment instance of the specified slice of the input
                 speech file.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.slice_from_file(filepath, start, end)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Create a silent speech segment of the given duration and
        sample rate; the transcript will be an empty string.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silence of the given duration.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.make_silence(duration, sample_rate)
        return cls(audio.samples, audio.sample_rate, "")

    @property
    def transcript(self):
        """Return the transcript text.

        :return: Transcript text for the speech.
        :rtype: basestring
        """
        return self._transcript
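
A small sketch of the class methods above, joining two silent segments; the data_utils.speech import path is an assumption:

    from data_utils.speech import SpeechSegment  # assumed module path

    a = SpeechSegment.make_silence(duration=0.5, sample_rate=16000)
    b = SpeechSegment.make_silence(duration=0.25, sample_rate=16000)
    joined = SpeechSegment.concatenate(a, b)
    print(joined.transcript)  # '' + '' -> empty transcript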
@ -0,0 +1,214 @@
"""Contains data helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import codecs
import os
import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file


def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Instances with durations outside [min_duration, max_duration] will be
    filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
        try:
            json_data = json.loads(json_line)
        except Exception as e:
            raise IOError("Error reading manifest: %s" % str(e))
        if (json_data["duration"] <= max_duration and
                json_data["duration"] >= min_duration):
            manifest.append(json_data)
    return manifest


def getfile_insensitive(path):
    """Get the actual file path when given a case-insensitive filename."""
    directory, filename = os.path.split(path)
    directory, filename = (directory or '.'), filename.lower()
    for f in os.listdir(directory):
        newpath = os.path.join(directory, f)
        if os.path.isfile(newpath) and f.lower() == filename:
            return newpath


def download_multi(url, target_dir, extra_args):
    """Download multiple files from url to target_dir."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    print("Downloading %s ..." % url)
    ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
                         target_dir)
    return ret_code


def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Checksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath


def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    tar = tarfile.open(filepath)
    tar.extractall(target_dir)
    tar.close()
    if rm_tar:
        os.remove(filepath)


class XmapEndSignal(object):
    pass


def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
    """A multiprocessing pipeline wrapper for the data reader.

    :param mapper: Function to map sample.
    :type mapper: callable
    :param reader: Given data reader.
    :type reader: callable
    :param process_num: Number of processes in the pipeline.
    :type process_num: int
    :param buffer_size: Maximal buffer size.
    :type buffer_size: int
    :return: The wrapped reader and a cleanup callback.
    :rtype: tuple
    """
    end_flag = XmapEndSignal()

    read_workers = []
    handle_workers = []
    flush_workers = []

    read_exit_flag = Value('i', 0)
    handle_exit_flag = Value('i', 0)
    flush_exit_flag = Value('i', 0)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        for order_id, sample in enumerate(reader()):
            if read_exit_flag.value == 1: break
            in_queue.put((order_id, sample))
        in_queue.put(end_flag)
        # the reading worker should not exit until all handling work exited
        while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
            time.sleep(0.001)

    # define a worker to handle samples from in_queue by mapper and put results
    # to out_queue with order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            if handle_exit_flag.value == 1: break
            order_id, sample = ins
            result = mapper(sample)
            while order_id != out_order[0]:
                time.sleep(0.001)
            out_queue.put(result)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end_flag)
        out_queue.put(end_flag)
        # wait for exit of flushing worker
        while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
            time.sleep(0.001)
        read_exit_flag.value = 1
        handle_exit_flag.value = 1

    # define a thread worker to flush samples from Manager.Queue to Queue
    # for acceleration
    def flush_worker(in_queue, out_queue):
        finish = 0
        while finish < process_num and flush_exit_flag.value == 0:
            sample = in_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                out_queue.put(sample)
        out_queue.put(end_flag)
        handle_exit_flag.value = 1
        flush_exit_flag.value = 1

    def cleanup():
        # first exit flushing workers
        flush_exit_flag.value = 1
        for w in flush_workers:
            w.join()
        # next exit handling workers
        handle_exit_flag.value = 1
        for w in handle_workers:
            w.join()
        # last exit reading workers
        read_exit_flag.value = 1
        for w in read_workers:
            w.join()

    def xreader():
        # prepare shared memory
        manager = Manager()
        in_queue = manager.Queue(buffer_size)
        out_queue = manager.Queue(buffer_size)
        out_order = manager.list([0])

        # start a read worker in a process
        target = order_read_worker
        p = Process(target=target, args=(reader, in_queue))
        p.daemon = True
        p.start()
        read_workers.append(p)

        # start handle_workers with multiple processes
        target = order_handle_worker
        args = (in_queue, out_queue, mapper, out_order)
        workers = [
            Process(target=target, args=args) for _ in xrange(process_num)
        ]
        for w in workers:
            w.daemon = True
            w.start()
            handle_workers.append(w)

        # start a thread to read data from the slow Manager.Queue
        flush_queue = Queue(buffer_size)
        t = Thread(target=flush_worker, args=(out_queue, flush_queue))
        t.daemon = True
        t.start()
        flush_workers.append(t)

        # get results
        sample = flush_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = flush_queue.get()

    return xreader, cleanup
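
A toy run of the pipeline wrapper; square and toy_reader are made up for illustration, and cleanup() must be called after iteration so the worker processes and the flushing thread can exit:

    def square(x):
        return x * x

    def toy_reader():
        for i in range(8):
            yield i

    wrapped_reader, cleanup = xmap_readers_mp(
        mapper=square, reader=toy_reader, process_num=2, buffer_size=4,
        order=True)
    for sample in wrapped_reader():
        print(sample)  # 0, 1, 4, 9, ... in order
    cleanup()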
@ -0,0 +1,238 @@
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from itertools import groupby
import numpy as np
from math import log
import multiprocessing


def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC beam search decoder.

    It utilizes beam search to approximately select the top best decoding
    labels and returns results in descending order.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), and the unclear parts are
    redesigned. Two important modifications: 1) in the iterative computation
    of probabilities, the assignment operation is changed to accumulation,
    since one prefix may come from different paths; 2) the if condition "if
    l^+ not in A_prev then" after the probabilities' computation is removed,
    since it is hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities are kept.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder is used in multiprocessing.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id assign
    blank_id = len(vocabulary)

    # If the decoder is called in multiprocessing, use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    ## initialize
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    ## extend prefix in loop
    for time_step in xrange(len(probs_seq)):
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(probs_seq[time_step]))
        cutoff_len = len(prob_idx)
        # if pruning is enabled
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for i in xrange(len(prob_idx)):
                cum_prob += prob_idx[i][1]
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
            prob_idx = prob_idx[0:cutoff_len]

        for l in prefix_set_prev:
            if not prefix_set_next.has_key(l):
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for index in xrange(cutoff_len):
                c, prob_c = prob_idx[index][0], prob_idx[index][1]

                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if not prefix_set_next.has_key(l_plus):
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                    elif new_char == ' ':
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        ## store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.iteritems(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            log_prob = log(prob)
            beam_result.append((log_prob, result))
        else:
            beam_result.append((float('-inf'), ''))

    ## output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities are kept.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # use a global variable to pass the external scorer to the beam search
    # decoder
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for i, probs_list in enumerate(probs_split):
        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
                None, nproc)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = [result.get() for result in results]
    return beam_search_results
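
A worked call of the two decoders above on a tiny, made-up distribution; the blank token sits at index len(vocabulary), as the code assumes:

    probs_seq = [[0.1, 0.2, 0.7],   # argmax -> blank (index 2)
                 [0.6, 0.3, 0.1],   # argmax -> 'a'
                 [0.5, 0.4, 0.1],   # argmax -> 'a' again, merged away
                 [0.1, 0.8, 0.1]]   # argmax -> 'b'
    vocabulary = ['a', 'b']          # blank id is len(vocabulary) == 2

    print(ctc_greedy_decoder(probs_seq, vocabulary))  # 'ab'

    # the beam search decoder returns (log probability, sentence) pairs
    for log_prob, sentence in ctc_beam_search_decoder(
            probs_seq, beam_size=4, vocabulary=vocabulary):
        print((log_prob, sentence))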
@ -0,0 +1,68 @@
"""External Scorer for Beam Search Decoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import kenlm
import numpy as np


class Scorer(object):
    """External scorer to evaluate a prefix or whole sentence in
    beam search decoding, including the score from an n-gram language
    model and a word count.

    :param alpha: Parameter associated with language model. Don't use the
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use the word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def _language_model_score(self, sentence):
        # log10 prob of last word
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def _word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # reset alpha and beta
    def reset_params(self, alpha, beta):
        self._alpha = alpha
        self._beta = beta

    # execute evaluation
    def __call__(self, sentence, log=False):
        """Evaluation function, gathering all the different scores
        and returning the final one.

        :param sentence: The input sentence for evaluation.
        :type sentence: basestring
        :param log: Whether to return the score in log representation.
        :type log: bool
        :return: Evaluation score, in decimal or log scale.
        :rtype: float
        """
        lm = self._language_model_score(sentence)
        word_cnt = self._word_count(sentence)
        if not log:
            score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        else:
            score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        return score
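
The two branches of __call__ compute the same score in different domains. A quick numeric check with stand-in values (no kenlm model is loaded here; lm and word_cnt are made up):

    import numpy as np

    alpha, beta = 2.0, 1.5
    lm, word_cnt = 0.01, 4          # stand-in values for illustration
    decimal = np.power(lm, alpha) * np.power(word_cnt, beta)
    log_form = alpha * np.log(lm) + beta * np.log(word_cnt)
    print(np.allclose(np.log(decimal), log_form))  # True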
@ -0,0 +1,19 @@
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


this_dir = os.path.dirname(__file__)

# Add project path to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
@ -0,0 +1,222 @@
#include "ctc_beam_search_decoder.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <utility>

#include "ThreadPool.h"
#include "fst/fstlib.h"

#include "decoder_utils.h"
#include "path_trie.h"

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;

std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer) {
  // dimension check
  size_t num_time_steps = probs_seq.size();
  for (size_t i = 0; i < num_time_steps; ++i) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
                   "the shape of the vocabulary");
  }

  // assign blank id
  size_t blank_id = vocabulary.size();

  // assign space id
  auto it = std::find(vocabulary.begin(), vocabulary.end(), " ");
  int space_id = it - vocabulary.begin();
  // if no space in vocabulary
  if ((size_t)space_id >= vocabulary.size()) {
    space_id = -2;
  }

  // init prefixes' root
  PathTrie root;
  root.score = root.log_prob_b_prev = 0.0;
  std::vector<PathTrie *> prefixes;
  prefixes.push_back(&root);

  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
    auto fst_dict = static_cast<fst::StdVectorFst *>(ext_scorer->dictionary);
    fst::StdVectorFst *dict_ptr = fst_dict->Copy(true);
    root.set_dictionary(dict_ptr);
    auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
    root.set_matcher(matcher);
  }

  // prefix search over time
  for (size_t time_step = 0; time_step < num_time_steps; ++time_step) {
    auto &prob = probs_seq[time_step];

    float min_cutoff = -NUM_FLT_INF;
    bool full_beam = false;
    if (ext_scorer != nullptr) {
      size_t num_prefixes = std::min(prefixes.size(), beam_size);
      std::sort(
          prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
      min_cutoff = prefixes[num_prefixes - 1]->score +
                   std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
      full_beam = (num_prefixes == beam_size);
    }

    std::vector<std::pair<size_t, float>> log_prob_idx =
        get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n);
    // loop over chars
    for (size_t index = 0; index < log_prob_idx.size(); index++) {
      auto c = log_prob_idx[index].first;
      auto log_prob_c = log_prob_idx[index].second;

      for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
        auto prefix = prefixes[i];
        if (full_beam && log_prob_c + prefix->score < min_cutoff) {
          break;
        }
        // blank
        if (c == blank_id) {
          prefix->log_prob_b_cur =
              log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
          continue;
        }
        // repeated character
        if (c == prefix->character) {
          prefix->log_prob_nb_cur = log_sum_exp(
              prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
        }
        // get new prefix
        auto prefix_new = prefix->get_path_trie(c);

        if (prefix_new != nullptr) {
          float log_p = -NUM_FLT_INF;

          if (c == prefix->character &&
              prefix->log_prob_b_prev > -NUM_FLT_INF) {
            log_p = log_prob_c + prefix->log_prob_b_prev;
          } else if (c != prefix->character) {
            log_p = log_prob_c + prefix->score;
          }

          // language model scoring
          if (ext_scorer != nullptr &&
              (c == space_id || ext_scorer->is_character_based())) {
            PathTrie *prefix_to_score = nullptr;
            // skip scoring the space
            if (ext_scorer->is_character_based()) {
              prefix_to_score = prefix_new;
            } else {
              prefix_to_score = prefix;
            }

            float score = 0.0;
            std::vector<std::string> ngram;
            ngram = ext_scorer->make_ngram(prefix_to_score);
            score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
            log_p += score;
            log_p += ext_scorer->beta;
          }
          prefix_new->log_prob_nb_cur =
              log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
        }
      }  // end of loop over prefix
    }    // end of loop over vocabulary

    prefixes.clear();
    // update log probs
    root.iterate_to_vec(prefixes);

    // only preserve top beam_size prefixes
    if (prefixes.size() >= beam_size) {
      std::nth_element(prefixes.begin(),
                       prefixes.begin() + beam_size,
                       prefixes.end(),
                       prefix_compare);
      for (size_t i = beam_size; i < prefixes.size(); ++i) {
        prefixes[i]->remove();
      }
    }
  }  // end of loop over time

  // score the last word of each prefix that doesn't end with space
  if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
      auto prefix = prefixes[i];
      if (!prefix->is_empty() && prefix->character != space_id) {
        float score = 0.0;
        std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
        score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
        score += ext_scorer->beta;
        prefix->score += score;
      }
    }
  }

  size_t num_prefixes = std::min(prefixes.size(), beam_size);
  std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);

  // compute approximate ctc score as the return score, without affecting the
  // return order of decoding result. To delete when decoder gets stable.
  for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
    double approx_ctc = prefixes[i]->score;
    if (ext_scorer != nullptr) {
      std::vector<int> output;
      prefixes[i]->get_path_vec(output);
      auto prefix_length = output.size();
      auto words = ext_scorer->split_labels(output);
      // remove word insertion weight
      approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
      // remove language model weight
      approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
    }
    prefixes[i]->approx_ctc = approx_ctc;
  }

  return get_beam_search_result(prefixes, vocabulary, beam_size);
}


std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
    const std::vector<std::vector<std::vector<double>>> &probs_split,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    size_t num_processes,
    double cutoff_prob,
    size_t cutoff_top_n,
    Scorer *ext_scorer) {
  VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!");
  // thread pool
  ThreadPool pool(num_processes);
  // number of samples
  size_t batch_size = probs_split.size();

  // enqueue the tasks of decoding
  std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
  for (size_t i = 0; i < batch_size; ++i) {
    res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
                                  probs_split[i],
                                  vocabulary,
                                  beam_size,
                                  cutoff_prob,
                                  cutoff_top_n,
                                  ext_scorer));
  }

  // get decoding results
  std::vector<std::vector<std::pair<double, std::string>>> batch_results;
  for (size_t i = 0; i < batch_size; ++i) {
    batch_results.emplace_back(res[i].get());
  }
  return batch_results;
}
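
The decoder above works entirely in the log domain and merges path probabilities with log_sum_exp. A short numpy sketch of why the naive form fails, independent of the C++ code:

    import numpy as np

    def log_sum_exp(a, b):
        """log(exp(a) + exp(b)) without overflow or underflow."""
        m = max(a, b)
        if m == -np.inf:           # both terms are zero probability
            return -np.inf
        return m + np.log(np.exp(a - m) + np.exp(b - m))

    a, b = -1000.0, -1001.0
    print(np.log(np.exp(a) + np.exp(b)))  # -inf: the direct form underflows
    print(log_sum_exp(a, b))              # ~-999.687, the correct value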
@ -0,0 +1,61 @@
#ifndef CTC_BEAM_SEARCH_DECODER_H_
#define CTC_BEAM_SEARCH_DECODER_H_

#include <string>
#include <utility>
#include <vector>

#include "scorer.h"

/* CTC Beam Search Decoder

 * Parameters:
 *     probs_seq: 2-D vector in which each element is a vector of
 *                probabilities over the vocabulary for one time step.
 *     vocabulary: A vector of vocabulary.
 *     beam_size: The width of beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
 * Return:
 *     A vector in which each element is a pair of score and decoding result,
 *     in descending order.
*/
std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    double cutoff_prob = 1.0,
    size_t cutoff_top_n = 40,
    Scorer *ext_scorer = nullptr);

/* CTC Beam Search Decoder for batch data

 * Parameters:
 *     probs_split: 3-D vector in which each element is a 2-D vector that can
 *                  be used by ctc_beam_search_decoder().
 *     vocabulary: A vector of vocabulary.
 *     beam_size: The width of beam search.
 *     num_processes: Number of threads for beam search.
 *     cutoff_prob: Cutoff probability for pruning.
 *     cutoff_top_n: Cutoff number for pruning.
 *     ext_scorer: External scorer to evaluate a prefix, which consists of
 *                 n-gram language model scoring and word insertion term.
 *                 Default null, decoding the input sample without scorer.
 * Return:
 *     A 2-D vector in which each element is a vector of beam search decoding
 *     results for one audio sample.
*/
std::vector<std::vector<std::pair<double, std::string>>>
ctc_beam_search_decoder_batch(
    const std::vector<std::vector<std::vector<double>>> &probs_split,
    const std::vector<std::string> &vocabulary,
    size_t beam_size,
    size_t num_processes,
    double cutoff_prob = 1.0,
    size_t cutoff_top_n = 40,
    Scorer *ext_scorer = nullptr);

#endif  // CTC_BEAM_SEARCH_DECODER_H_
@ -0,0 +1,45 @@
#include "ctc_greedy_decoder.h"
#include "decoder_utils.h"

std::string ctc_greedy_decoder(
    const std::vector<std::vector<double>> &probs_seq,
    const std::vector<std::string> &vocabulary) {
  // dimension check
  size_t num_time_steps = probs_seq.size();
  for (size_t i = 0; i < num_time_steps; ++i) {
    VALID_CHECK_EQ(probs_seq[i].size(),
                   vocabulary.size() + 1,
                   "The shape of probs_seq does not match with "
                   "the shape of the vocabulary");
  }

  size_t blank_id = vocabulary.size();

  std::vector<size_t> max_idx_vec(num_time_steps, 0);
  std::vector<size_t> idx_vec;
  for (size_t i = 0; i < num_time_steps; ++i) {
    double max_prob = 0.0;
    size_t max_idx = 0;
    const std::vector<double> &probs_step = probs_seq[i];
    for (size_t j = 0; j < probs_step.size(); ++j) {
      if (max_prob < probs_step[j]) {
        max_idx = j;
        max_prob = probs_step[j];
      }
    }
    // id with maximum probability in current time step
    max_idx_vec[i] = max_idx;
    // deduplicate
    if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
      idx_vec.push_back(max_idx_vec[i]);
    }
  }

  std::string best_path_result;
  for (size_t i = 0; i < idx_vec.size(); ++i) {
    if (idx_vec[i] != blank_id) {
      best_path_result += vocabulary[idx_vec[i]];
    }
  }
  return best_path_result;
}
@ -0,0 +1,20 @@
#ifndef CTC_GREEDY_DECODER_H
#define CTC_GREEDY_DECODER_H

#include <string>
#include <vector>

/* CTC Greedy (Best Path) Decoder
 *
 * Parameters:
 *     probs_seq: 2-D vector in which each element is a vector of
 *                probabilities over the vocabulary for one time step.
 *     vocabulary: A vector of vocabulary.
 * Return:
 *     The decoding result as a string.
*/
std::string ctc_greedy_decoder(
    const std::vector<std::vector<double>>& probs_seq,
    const std::vector<std::string>& vocabulary);

#endif  // CTC_GREEDY_DECODER_H
@ -0,0 +1,176 @@
#include "decoder_utils.h"

#include <algorithm>
#include <cmath>
#include <limits>

std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n) {
  std::vector<std::pair<int, double>> prob_idx;
  for (size_t i = 0; i < prob_step.size(); ++i) {
    prob_idx.push_back(std::pair<int, double>(i, prob_step[i]));
  }
  // prune the vocabulary
  size_t cutoff_len = prob_step.size();
  if (cutoff_prob < 1.0 || cutoff_top_n < cutoff_len) {
    std::sort(
        prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>);
    if (cutoff_prob < 1.0) {
      double cum_prob = 0.0;
      cutoff_len = 0;
      for (size_t i = 0; i < prob_idx.size(); ++i) {
        cum_prob += prob_idx[i].second;
        cutoff_len += 1;
        if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break;
      }
    }
    prob_idx = std::vector<std::pair<int, double>>(
        prob_idx.begin(), prob_idx.begin() + cutoff_len);
  }
  std::vector<std::pair<size_t, float>> log_prob_idx;
  for (size_t i = 0; i < cutoff_len; ++i) {
    log_prob_idx.push_back(std::pair<size_t, float>(
        prob_idx[i].first, std::log(prob_idx[i].second + NUM_FLT_MIN)));
  }
  return log_prob_idx;
}


std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size) {
  // allow for the post processing
  std::vector<PathTrie *> space_prefixes;
  if (space_prefixes.empty()) {
    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
      space_prefixes.push_back(prefixes[i]);
    }
  }

  std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
  std::vector<std::pair<double, std::string>> output_vecs;
  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
    std::vector<int> output;
    space_prefixes[i]->get_path_vec(output);
    // convert indices to a string
    std::string output_str;
    for (size_t j = 0; j < output.size(); j++) {
      output_str += vocabulary[output[j]];
    }
    std::pair<double, std::string> output_pair(-space_prefixes[i]->approx_ctc,
                                               output_str);
    output_vecs.emplace_back(output_pair);
  }

  return output_vecs;
}

size_t get_utf8_str_len(const std::string &str) {
  size_t str_len = 0;
  for (char c : str) {
    str_len += ((c & 0xc0) != 0x80);
  }
  return str_len;
}

std::vector<std::string> split_utf8_str(const std::string &str) {
  std::vector<std::string> result;
  std::string out_str;

  for (char c : str) {
    if ((c & 0xc0) != 0x80)  // new UTF-8 character
    {
      if (!out_str.empty()) {
        result.push_back(out_str);
        out_str.clear();
      }
    }

    out_str.append(1, c);
  }
  result.push_back(out_str);
  return result;
}

std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim) {
  std::vector<std::string> result;
  std::size_t start = 0, delim_len = delim.size();
  while (true) {
    std::size_t end = s.find(delim, start);
    if (end == std::string::npos) {
      if (start < s.size()) {
        result.push_back(s.substr(start));
      }
      break;
    }
    if (end > start) {
      result.push_back(s.substr(start, end - start));
    }
    start = end + delim_len;
  }
  return result;
}

bool prefix_compare(const PathTrie *x, const PathTrie *y) {
  if (x->score == y->score) {
    if (x->character == y->character) {
      return false;
    } else {
      return (x->character < y->character);
    }
  } else {
    return x->score > y->score;
  }
}

void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary) {
  if (dictionary->NumStates() == 0) {
    fst::StdVectorFst::StateId start = dictionary->AddState();
    assert(start == 0);
    dictionary->SetStart(start);
  }
  fst::StdVectorFst::StateId src = dictionary->Start();
  fst::StdVectorFst::StateId dst;
  for (auto c : word) {
    dst = dictionary->AddState();
    dictionary->AddArc(src, fst::StdArc(c, c, 0, dst));
    src = dst;
  }
  dictionary->SetFinal(dst, fst::StdArc::Weight::One());
}

bool add_word_to_dictionary(
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
    bool add_space,
    int SPACE_ID,
    fst::StdVectorFst *dictionary) {
  auto characters = split_utf8_str(word);

  std::vector<int> int_word;

  for (auto &c : characters) {
    if (c == " ") {
      int_word.push_back(SPACE_ID);
    } else {
      auto int_c = char_map.find(c);
      if (int_c != char_map.end()) {
        int_word.push_back(int_c->second);
      } else {
        return false;  // return without adding
      }
    }
  }

  if (add_space) {
    int_word.push_back(SPACE_ID);
  }

  add_word_to_fst(int_word, dictionary);
  return true;  // return with successful adding
}
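// Worked example for the two UTF-8 helpers above (the byte values are
// standard UTF-8; the string itself is invented). For "aé", the bytes are
// {0x61, 0xC3, 0xA9}: 0x61 and 0xC3 both satisfy (c & 0xC0) != 0x80 (they
// start a character), while the continuation byte 0xA9 does not. Hence:
//
//   get_utf8_str_len("a\xc3\xa9") == 2
//   split_utf8_str("a\xc3\xa9")   == {"a", "\xc3\xa9"}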
@ -0,0 +1,94 @@
#ifndef DECODER_UTILS_H_
#define DECODER_UTILS_H_

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "fst/log.h"
#include "path_trie.h"

const float NUM_FLT_INF = std::numeric_limits<float>::max();
const float NUM_FLT_MIN = std::numeric_limits<float>::min();

// inline function for validation check
inline void check(
    bool x, const char *expr, const char *file, int line, const char *err) {
  if (!x) {
    std::cout << "[" << file << ":" << line << "] ";
    LOG(FATAL) << "\"" << expr << "\" check failed. " << err;
  }
}

#define VALID_CHECK(x, info) \
  check(static_cast<bool>(x), #x, __FILE__, __LINE__, info)
#define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info)
#define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info)
#define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info)


// Function template for comparing two pairs by the first element, descending
template <typename T1, typename T2>
bool pair_comp_first_rev(const std::pair<T1, T2> &a,
                         const std::pair<T1, T2> &b) {
  return a.first > b.first;
}

// Function template for comparing two pairs by the second element, descending
template <typename T1, typename T2>
bool pair_comp_second_rev(const std::pair<T1, T2> &a,
                          const std::pair<T1, T2> &b) {
  return a.second > b.second;
}

// Return the sum of two probabilities in log scale
template <typename T>
T log_sum_exp(const T &x, const T &y) {
  static T num_min = -std::numeric_limits<T>::max();
  if (x <= num_min) return y;
  if (y <= num_min) return x;
  T xmax = std::max(x, y);
  return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax;
}

// Get the pruned probability vector for each time step's beam search
std::vector<std::pair<size_t, float>> get_pruned_log_probs(
    const std::vector<double> &prob_step,
    double cutoff_prob,
    size_t cutoff_top_n);

// Get the beam search result from prefixes in the trie tree
std::vector<std::pair<double, std::string>> get_beam_search_result(
    const std::vector<PathTrie *> &prefixes,
    const std::vector<std::string> &vocabulary,
    size_t beam_size);

// Comparison function for prefix ordering
bool prefix_compare(const PathTrie *x, const PathTrie *y);

/* Get the length of a UTF-8 encoded string
 * See: http://stackoverflow.com/a/4063229
 */
size_t get_utf8_str_len(const std::string &str);

/* Split a string into a list of strings on a given string
 * delimiter. NB: delimiters on beginning / end of string are
 * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"].
 */
std::vector<std::string> split_str(const std::string &s,
                                   const std::string &delim);

/* Splits a string into a vector of strings representing
 * UTF-8 characters (not the same as chars)
 */
std::vector<std::string> split_utf8_str(const std::string &str);

// Add a word in index form to the FST dictionary
void add_word_to_fst(const std::vector<int> &word,
                     fst::StdVectorFst *dictionary);

// Add a word in string form to the dictionary
bool add_word_to_dictionary(
    const std::string &word,
    const std::unordered_map<std::string, int> &char_map,
    bool add_space,
    int SPACE_ID,
    fst::StdVectorFst *dictionary);
#endif  // DECODER_UTILS_H_
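// Worked example for log_sum_exp above (numbers invented). Given
// probabilities p1 = 0.5 and p2 = 0.25 stored as x = log(0.5) and
// y = log(0.25), the function computes
//
//   log(exp(x) + exp(y)) = log(0.5 + 0.25) = log(0.75)
//
// but factors out xmax = max(x, y) first, so the exponentials stay close to
// 1 and do not underflow even for very negative log probabilities:
//
//   float z = log_sum_exp(std::log(0.5f), std::log(0.25f));
//   // z ~= std::log(0.75f)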
@ -0,0 +1,33 @@
%module swig_decoders
%{
#include "scorer.h"
#include "ctc_greedy_decoder.h"
#include "ctc_beam_search_decoder.h"
#include "decoder_utils.h"
%}

%include "std_vector.i"
%include "std_pair.i"
%include "std_string.i"
%import "decoder_utils.h"

namespace std {
    %template(DoubleVector) std::vector<double>;
    %template(IntVector) std::vector<int>;
    %template(StringVector) std::vector<std::string>;
    %template(VectorOfStructVector) std::vector<std::vector<double> >;
    %template(FloatVector) std::vector<float>;
    %template(Pair) std::pair<float, std::string>;
    %template(PairFloatStringVector) std::vector<std::pair<float, std::string> >;
    %template(PairDoubleStringVector) std::vector<std::pair<double, std::string> >;
    %template(PairDoubleStringVector2) std::vector<std::vector<std::pair<double, std::string> > >;
    %template(DoubleVector3) std::vector<std::vector<std::vector<double> > >;
}

%template(IntDoublePairCompSecondRev) pair_comp_second_rev<int, double>;
%template(StringDoublePairCompSecondRev) pair_comp_second_rev<std::string, double>;
%template(DoubleStringPairCompFirstRev) pair_comp_first_rev<double, std::string>;

%include "scorer.h"
%include "ctc_greedy_decoder.h"
%include "ctc_beam_search_decoder.h"
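/* Illustrative use of the generated module (a hedged sketch; it assumes the
 * build in setup.py has already produced the Python package `swig_decoders`).
 * The %template directives above make the C++ containers visible to Python:
 *
 *   import swig_decoders
 *   probs = swig_decoders.VectorOfStructVector([[0.1, 0.9], [0.8, 0.2]])
 *   vocab = swig_decoders.StringVector(["a", "b"])
 *
 * Plain Python lists are usually accepted as well through SWIG's implicit
 * conversions, which is what the Python wrapper module relies on.
 */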
@ -0,0 +1,148 @@
#include "path_trie.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "decoder_utils.h"

PathTrie::PathTrie() {
  log_prob_b_prev = -NUM_FLT_INF;
  log_prob_nb_prev = -NUM_FLT_INF;
  log_prob_b_cur = -NUM_FLT_INF;
  log_prob_nb_cur = -NUM_FLT_INF;
  score = -NUM_FLT_INF;

  ROOT_ = -1;
  character = ROOT_;
  exists_ = true;
  parent = nullptr;

  dictionary_ = nullptr;
  dictionary_state_ = 0;
  has_dictionary_ = false;

  matcher_ = nullptr;
}

PathTrie::~PathTrie() {
  for (auto child : children_) {
    delete child.second;
  }
}

PathTrie* PathTrie::get_path_trie(int new_char, bool reset) {
  auto child = children_.begin();
  for (child = children_.begin(); child != children_.end(); ++child) {
    if (child->first == new_char) {
      break;
    }
  }
  if (child != children_.end()) {
    if (!child->second->exists_) {
      child->second->exists_ = true;
      child->second->log_prob_b_prev = -NUM_FLT_INF;
      child->second->log_prob_nb_prev = -NUM_FLT_INF;
      child->second->log_prob_b_cur = -NUM_FLT_INF;
      child->second->log_prob_nb_cur = -NUM_FLT_INF;
    }
    return (child->second);
  } else {
    if (has_dictionary_) {
      matcher_->SetState(dictionary_state_);
      bool found = matcher_->Find(new_char + 1);
      if (!found) {
        // Adding this character would take the word outside the dictionary
        auto FSTZERO = fst::TropicalWeight::Zero();
        auto final_weight = dictionary_->Final(dictionary_state_);
        bool is_final = (final_weight != FSTZERO);
        if (is_final && reset) {
          dictionary_state_ = dictionary_->Start();
        }
        return nullptr;
      } else {
        PathTrie* new_path = new PathTrie;
        new_path->character = new_char;
        new_path->parent = this;
        new_path->dictionary_ = dictionary_;
        new_path->dictionary_state_ = matcher_->Value().nextstate;
        new_path->has_dictionary_ = true;
        new_path->matcher_ = matcher_;
        children_.push_back(std::make_pair(new_char, new_path));
        return new_path;
      }
    } else {
      PathTrie* new_path = new PathTrie;
      new_path->character = new_char;
      new_path->parent = this;
      children_.push_back(std::make_pair(new_char, new_path));
      return new_path;
    }
  }
}

PathTrie* PathTrie::get_path_vec(std::vector<int>& output) {
  return get_path_vec(output, ROOT_);
}

PathTrie* PathTrie::get_path_vec(std::vector<int>& output,
                                 int stop,
                                 size_t max_steps) {
  if (character == stop || character == ROOT_ || output.size() == max_steps) {
    std::reverse(output.begin(), output.end());
    return this;
  } else {
    output.push_back(character);
    return parent->get_path_vec(output, stop, max_steps);
  }
}

void PathTrie::iterate_to_vec(std::vector<PathTrie*>& output) {
  if (exists_) {
    log_prob_b_prev = log_prob_b_cur;
    log_prob_nb_prev = log_prob_nb_cur;

    log_prob_b_cur = -NUM_FLT_INF;
    log_prob_nb_cur = -NUM_FLT_INF;

    score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev);
    output.push_back(this);
  }
  for (auto child : children_) {
    child.second->iterate_to_vec(output);
  }
}

void PathTrie::remove() {
  exists_ = false;

  if (children_.size() == 0) {
    auto child = parent->children_.begin();
    for (child = parent->children_.begin(); child != parent->children_.end();
         ++child) {
      if (child->first == character) {
        parent->children_.erase(child);
        break;
      }
    }

    if (parent->children_.size() == 0 && !parent->exists_) {
      parent->remove();
    }

    delete this;
  }
}

void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) {
  dictionary_ = dictionary;
  dictionary_state_ = dictionary->Start();
  has_dictionary_ = true;
}

using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
void PathTrie::set_matcher(std::shared_ptr<FSTMATCH> matcher) {
  matcher_ = matcher;
}
@ -0,0 +1,67 @@
#ifndef PATH_TRIE_H
#define PATH_TRIE_H

#include <algorithm>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "fst/fstlib.h"

/* Trie tree for prefix storing and manipulating, with a dictionary in
 * finite-state transducer for spelling correction.
 */
class PathTrie {
public:
  PathTrie();
  ~PathTrie();

  // get the new prefix after appending a new char
  PathTrie* get_path_trie(int new_char, bool reset = true);

  // get the prefix in index form from the root to the current node
  PathTrie* get_path_vec(std::vector<int>& output);

  // get the prefix in index form from some stop node to the current node
  PathTrie* get_path_vec(std::vector<int>& output,
                         int stop,
                         size_t max_steps = std::numeric_limits<size_t>::max());

  // update log probs
  void iterate_to_vec(std::vector<PathTrie*>& output);

  // set dictionary for FST
  void set_dictionary(fst::StdVectorFst* dictionary);

  void set_matcher(std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>>);

  bool is_empty() { return ROOT_ == character; }

  // remove current path from root
  void remove();

  float log_prob_b_prev;
  float log_prob_nb_prev;
  float log_prob_b_cur;
  float log_prob_nb_cur;
  float score;
  float approx_ctc;
  int character;
  PathTrie* parent;

private:
  int ROOT_;
  bool exists_;
  bool has_dictionary_;

  std::vector<std::pair<int, PathTrie*>> children_;

  // pointer to the dictionary FST
  fst::StdVectorFst* dictionary_;
  fst::StdVectorFst::StateId dictionary_state_;
  // matcher used to find arcs in the FST
  std::shared_ptr<fst::SortedMatcher<fst::StdVectorFst>> matcher_;
};

#endif  // PATH_TRIE_H
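// Illustrative walk through the trie API (a sketch with invented character
// ids; real callers pass vocabulary indices as the decoder expects).
//
//   PathTrie root;                             // empty prefix
//   PathTrie* pa  = root.get_path_trie(3);     // prefix [3]
//   PathTrie* pab = pa->get_path_trie(7);      // prefix [3, 7]
//   std::vector<int> path;
//   pab->get_path_vec(path);                   // path == {3, 7}
//
// With set_dictionary()/set_matcher() configured, get_path_trie() instead
// returns nullptr when the extended prefix cannot continue any dictionary
// word, which is how the FST constrains the beam search.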
@ -0,0 +1,230 @@
#include "scorer.h"

#include <unistd.h>
#include <cassert>
#include <iostream>

#include "lm/config.hh"
#include "lm/model.hh"
#include "lm/state.hh"
#include "util/string_piece.hh"
#include "util/tokenize_piece.hh"

#include "decoder_utils.h"

using namespace lm::ngram;

Scorer::Scorer(double alpha,
               double beta,
               const std::string& lm_path,
               const std::vector<std::string>& vocab_list) {
  this->alpha = alpha;
  this->beta = beta;

  dictionary = nullptr;
  is_character_based_ = true;
  language_model_ = nullptr;

  max_order_ = 0;
  dict_size_ = 0;
  SPACE_ID_ = -1;

  setup(lm_path, vocab_list);
}

Scorer::~Scorer() {
  if (language_model_ != nullptr) {
    delete static_cast<lm::base::Model*>(language_model_);
  }
  if (dictionary != nullptr) {
    delete static_cast<fst::StdVectorFst*>(dictionary);
  }
}

void Scorer::setup(const std::string& lm_path,
                   const std::vector<std::string>& vocab_list) {
  // load language model
  load_lm(lm_path);
  // set char map for scorer
  set_char_map(vocab_list);
  // fill the dictionary for FST
  if (!is_character_based()) {
    fill_dictionary(true);
  }
}

void Scorer::load_lm(const std::string& lm_path) {
  const char* filename = lm_path.c_str();
  VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");

  RetriveStrEnumerateVocab enumerate;
  lm::ngram::Config config;
  config.enumerate_vocab = &enumerate;
  language_model_ = lm::ngram::LoadVirtual(filename, config);
  max_order_ = static_cast<lm::base::Model*>(language_model_)->Order();
  vocabulary_ = enumerate.vocabulary;
  for (size_t i = 0; i < vocabulary_.size(); ++i) {
    if (is_character_based_ && vocabulary_[i] != UNK_TOKEN &&
        vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN &&
        get_utf8_str_len(enumerate.vocabulary[i]) > 1) {
      is_character_based_ = false;
    }
  }
}

double Scorer::get_log_cond_prob(const std::vector<std::string>& words) {
  lm::base::Model* model = static_cast<lm::base::Model*>(language_model_);
  double cond_prob = OOV_SCORE;  // returned as-is if words is empty
  lm::ngram::State state, tmp_state, out_state;
  // avoid inserting <s> at the beginning
  model->NullContextWrite(&state);
  for (size_t i = 0; i < words.size(); ++i) {
    lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]);
    // encountered an OOV word
    if (word_index == 0) {
      return OOV_SCORE;
    }
    cond_prob = model->BaseScore(&state, word_index, &out_state);
    tmp_state = state;
    state = out_state;
    out_state = tmp_state;
  }
  // return the log10 probability
  return cond_prob;
}

double Scorer::get_sent_log_prob(const std::vector<std::string>& words) {
  std::vector<std::string> sentence;
  if (words.size() == 0) {
    for (size_t i = 0; i < max_order_; ++i) {
      sentence.push_back(START_TOKEN);
    }
  } else {
    for (size_t i = 0; i < max_order_ - 1; ++i) {
      sentence.push_back(START_TOKEN);
    }
    sentence.insert(sentence.end(), words.begin(), words.end());
  }
  sentence.push_back(END_TOKEN);
  return get_log_prob(sentence);
}

double Scorer::get_log_prob(const std::vector<std::string>& words) {
  assert(words.size() > max_order_);
  double score = 0.0;
  for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) {
    std::vector<std::string> ngram(words.begin() + i,
                                   words.begin() + i + max_order_);
    score += get_log_cond_prob(ngram);
  }
  return score;
}

void Scorer::reset_params(float alpha, float beta) {
  this->alpha = alpha;
  this->beta = beta;
}

std::string Scorer::vec2str(const std::vector<int>& input) {
  std::string word;
  for (auto ind : input) {
    word += char_list_[ind];
  }
  return word;
}

std::vector<std::string> Scorer::split_labels(const std::vector<int>& labels) {
  if (labels.empty()) return {};

  std::string s = vec2str(labels);
  std::vector<std::string> words;
  if (is_character_based_) {
    words = split_utf8_str(s);
  } else {
    words = split_str(s, " ");
  }
  return words;
}

void Scorer::set_char_map(const std::vector<std::string>& char_list) {
  char_list_ = char_list;
  char_map_.clear();

  // Set the char map for the FST for spelling correction
  for (size_t i = 0; i < char_list_.size(); i++) {
    if (char_list_[i] == " ") {
      SPACE_ID_ = i;
    }
    // The initial state of the FST is state 0, hence the index of chars in
    // the FST should start from 1 to avoid conflict with the initial
    // state; otherwise wrong decoding results would be given.
    char_map_[char_list_[i]] = i + 1;
  }
}

std::vector<std::string> Scorer::make_ngram(PathTrie* prefix) {
  std::vector<std::string> ngram;
  PathTrie* current_node = prefix;
  PathTrie* new_node = nullptr;

  for (int order = 0; order < max_order_; order++) {
    std::vector<int> prefix_vec;

    if (is_character_based_) {
      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_, 1);
      current_node = new_node;
    } else {
      new_node = current_node->get_path_vec(prefix_vec, SPACE_ID_);
      current_node = new_node->parent;  // Skipping spaces
    }

    // reconstruct word
    std::string word = vec2str(prefix_vec);
    ngram.push_back(word);

    if (new_node->character == -1) {
      // No more spaces, but we still need to pad up to the LM order
      for (int i = 0; i < max_order_ - order - 1; i++) {
        ngram.push_back(START_TOKEN);
      }
      break;
    }
  }
  std::reverse(ngram.begin(), ngram.end());
  return ngram;
}

void Scorer::fill_dictionary(bool add_space) {
  fst::StdVectorFst dictionary;
  // For each unigram, convert it to ints and put it in the trie
  int dict_size = 0;
  for (const auto& word : vocabulary_) {
    bool added = add_word_to_dictionary(
        word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
    dict_size += added ? 1 : 0;
  }

  dict_size_ = dict_size;

  /* Simplify FST

   * This gets rid of "epsilon" transitions in the FST.
   * These are transitions that don't require a string input to be taken.
   * Getting rid of them is necessary to make the FST deterministic, but
   * can greatly increase the size of the FST.
   */
  fst::RmEpsilon(&dictionary);
  fst::StdVectorFst* new_dict = new fst::StdVectorFst;

  /* This makes the FST deterministic, meaning for any string input there's
   * only one possible state the FST could be in. It is assumed our
   * dictionary is deterministic when using it.
   * (lest we'd have to check for multiple transitions at each state)
   */
  fst::Determinize(dictionary, new_dict);

  /* Finds the simplest equivalent fst. This is unnecessary but decreases
   * memory usage of the dictionary.
   */
  fst::Minimize(new_dict);
  this->dictionary = new_dict;
}
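// Worked example for get_sent_log_prob above (a sketch; the word strings are
// invented). With a trigram LM (max_order_ == 3) and words = {"HELLO",
// "WORLD"}, the padded sentence becomes
//
//   {"<s>", "<s>", "HELLO", "WORLD", "</s>"}
//
// and get_log_prob sums the scores of the sliding trigrams
//
//   P(HELLO | <s> <s>), P(WORLD | <s> HELLO), P(</s> | HELLO WORLD)
//
// in log10 space, i.e. three calls to get_log_cond_prob.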
@ -0,0 +1,112 @@
#ifndef SCORER_H_
#define SCORER_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "lm/enumerate_vocab.hh"
#include "lm/virtual_interface.hh"
#include "lm/word_index.hh"
#include "util/string_piece.hh"

#include "path_trie.h"

const double OOV_SCORE = -1000.0;
const std::string START_TOKEN = "<s>";
const std::string UNK_TOKEN = "<unk>";
const std::string END_TOKEN = "</s>";

// Implements a callback to retrieve the dictionary of the language model.
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
public:
  RetriveStrEnumerateVocab() {}

  void Add(lm::WordIndex index, const StringPiece &str) {
    vocabulary.push_back(std::string(str.data(), str.length()));
  }

  std::vector<std::string> vocabulary;
};

/* External scorer to query the score for an n-gram or sentence, including
 * language model scoring and word insertion.
 *
 * Example:
 *     Scorer scorer(alpha, beta, "path_of_language_model", vocab_list);
 *     scorer.get_log_cond_prob({ "WORD1", "WORD2", "WORD3" });
 *     scorer.get_sent_log_prob({ "WORD1", "WORD2", "WORD3" });
 */
class Scorer {
public:
  Scorer(double alpha,
         double beta,
         const std::string &lm_path,
         const std::vector<std::string> &vocabulary);
  ~Scorer();

  double get_log_cond_prob(const std::vector<std::string> &words);

  double get_sent_log_prob(const std::vector<std::string> &words);

  // return the max order
  size_t get_max_order() const { return max_order_; }

  // return the dictionary size of the language model
  size_t get_dict_size() const { return dict_size_; }

  // return true if the language model is character based
  bool is_character_based() const { return is_character_based_; }

  // reset params alpha & beta
  void reset_params(float alpha, float beta);

  // make an n-gram for a given prefix
  std::vector<std::string> make_ngram(PathTrie *prefix);

  // transform the labels in index form into a vector of words (word based lm)
  // or a vector of characters (character based lm)
  std::vector<std::string> split_labels(const std::vector<int> &labels);

  // language model weight
  double alpha;
  // word insertion weight
  double beta;

  // pointer to the dictionary FST
  void *dictionary;

protected:
  // necessary setup: load language model, set char map, fill FST's dictionary
  void setup(const std::string &lm_path,
             const std::vector<std::string> &vocab_list);

  // load language model from the given path
  void load_lm(const std::string &lm_path);

  // fill the dictionary FST
  void fill_dictionary(bool add_space);

  // set char map
  void set_char_map(const std::vector<std::string> &char_list);

  double get_log_prob(const std::vector<std::string> &words);

  // translate a vector of indices into a string
  std::string vec2str(const std::vector<int> &input);

private:
  void *language_model_;
  bool is_character_based_;
  size_t max_order_;
  size_t dict_size_;

  int SPACE_ID_;
  std::vector<std::string> char_list_;
  std::unordered_map<std::string, int> char_map_;

  std::vector<std::string> vocabulary_;
};

#endif  // SCORER_H_
@ -0,0 +1,119 @@
"""Script to build and install decoder package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from setuptools import setup, Extension, distutils
import glob
import platform
import os, sys
import multiprocessing.pool
import argparse

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--num_processes",
    default=1,
    type=int,
    help="Number of cpu processes to build package. (default: %(default)d)")
args = parser.parse_known_args()

# reconstruct sys.argv to pass to setup below
sys.argv = [sys.argv[0]] + args[1]


# monkey-patch for parallel compilation
# See: https://stackoverflow.com/a/13176803
def parallelCCompile(self,
                     sources,
                     output_dir=None,
                     macros=None,
                     include_dirs=None,
                     debug=0,
                     extra_preargs=None,
                     extra_postargs=None,
                     depends=None):
    # these lines are copied directly from distutils.ccompiler.CCompiler
    macros, objects, extra_postargs, pp_opts, build = self._setup_compile(
        output_dir, macros, include_dirs, sources, depends, extra_postargs)
    cc_args = self._get_cc_args(pp_opts, debug, extra_preargs)

    # parallel code
    def _single_compile(obj):
        try:
            src, ext = build[obj]
        except KeyError:
            return
        self._compile(obj, src, ext, cc_args, extra_postargs, pp_opts)

    # convert to list, imap is evaluated on-demand
    thread_pool = multiprocessing.pool.ThreadPool(args[0].num_processes)
    list(thread_pool.imap(_single_compile, objects))
    return objects


def compile_test(header, library):
    dummy_path = os.path.join(os.path.dirname(__file__), "dummy")
    command = "bash -c \"g++ -include " + header \
              + " -l" + library + " -x c++ - <<<'int main() {}' -o " \
              + dummy_path + " >/dev/null 2>/dev/null && rm " \
              + dummy_path + " 2>/dev/null\""
    return os.system(command) == 0


# hack compile to support parallel compiling
distutils.ccompiler.CCompiler.compile = parallelCCompile

FILES = glob.glob('kenlm/util/*.cc') \
        + glob.glob('kenlm/lm/*.cc') \
        + glob.glob('kenlm/util/double-conversion/*.cc')

FILES += glob.glob('openfst-1.6.3/src/lib/*.cc')

FILES = [
    fn for fn in FILES
    if not (fn.endswith('main.cc') or fn.endswith('test.cc') or fn.endswith(
        'unittest.cc'))
]

LIBS = ['stdc++']
if platform.system() != 'Darwin':
    LIBS.append('rt')

ARGS = ['-O3', '-DNDEBUG', '-DKENLM_MAX_ORDER=6', '-std=c++11']

if compile_test('zlib.h', 'z'):
    ARGS.append('-DHAVE_ZLIB')
    LIBS.append('z')

if compile_test('bzlib.h', 'bz2'):
    ARGS.append('-DHAVE_BZLIB')
    LIBS.append('bz2')

if compile_test('lzma.h', 'lzma'):
    ARGS.append('-DHAVE_XZLIB')
    LIBS.append('lzma')

os.system('swig -python -c++ ./decoders.i')

decoders_module = [
    Extension(
        name='_swig_decoders',
        sources=FILES + glob.glob('*.cxx') + glob.glob('*.cpp'),
        language='c++',
        include_dirs=[
            '.',
            'kenlm',
            'openfst-1.6.3/src/include',
            'ThreadPool',
        ],
        libraries=LIBS,
        extra_compile_args=ARGS)
]

setup(
    name='swig_decoders',
    version='1.1',
    description="""CTC decoders""",
    ext_modules=decoders_module,
    py_modules=['swig_decoders'], )
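# What compile_test('zlib.h', 'z') effectively runs (a sketch; the dummy
# output path is whatever dummy_path resolves to next to this script):
#
#   bash -c "g++ -include zlib.h -lz -x c++ - <<<'int main() {}' \
#       -o ./dummy >/dev/null 2>/dev/null && rm ./dummy 2>/dev/null"
#
# i.e. it probes for a usable header/library pair by compiling an empty
# program, and the exit status decides whether the -DHAVE_* flag is added.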
@ -0,0 +1,21 @@
#!/usr/bin/env bash

if [ ! -d kenlm ]; then
    git clone https://github.com/luotao1/kenlm.git
    echo -e "\n"
fi

if [ ! -d openfst-1.6.3 ]; then
    echo "Download and extract openfst ..."
    wget http://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.6.3.tar.gz
    tar -xzvf openfst-1.6.3.tar.gz
    echo -e "\n"
fi

if [ ! -d ThreadPool ]; then
    git clone https://github.com/progschj/ThreadPool.git
    echo -e "\n"
fi

echo "Install decoders ..."
python setup.py install --num_processes 4
@ -0,0 +1,124 @@
"""Wrapper for various CTC decoders in SWIG."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import swig_decoders


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: basestring
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary):
    """Wrapper for the CTC best path decoder in swig.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: basestring
    """
    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary)
    return result.decode('utf-8')


def ctc_beam_search_decoder(probs_seq,
                            vocabulary,
                            beam_size,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None):
    """Wrapper for the CTC Beam Search Decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    beam_results = swig_decoders.ctc_beam_search_decoder(
        probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
        ext_scoring_func)
    beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
    return beam_results


def ctc_beam_search_decoder_batch(probs_split,
                                  vocabulary,
                                  beam_size,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    probs_split = [probs_seq.tolist() for probs_seq in probs_split]

    batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func)
    batch_beam_results = [
        [(res[0], res[1].decode("utf-8")) for res in beam_results]
        for beam_results in batch_beam_results
    ]
    return batch_beam_results
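# Illustrative usage (a sketch; the probabilities are invented and the
# wrappers above expect numpy arrays, since they call .tolist()):
#
#   import numpy as np
#   probs = np.array([[0.1, 0.7, 0.2],
#                     [0.8, 0.1, 0.1]])      # 2 time steps, 2 chars + blank
#   vocab = ['a', 'b']
#   best = ctc_greedy_decoder(probs, vocab)
#   beams = ctc_beam_search_decoder(probs, vocab, beam_size=20)
#   # beams[0] is the (log prob, sentence) pair with the highest score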
@ -0,0 +1,90 @@
"""Test decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
from decoders import decoders_deprecated as decoder


class TestDecoders(unittest.TestCase):
    def setUp(self):
        self.vocab_list = ["\'", ' ', 'a', 'b', 'c', 'd']
        self.beam_size = 20
        self.probs_seq1 = [[
            0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254,
            0.18184413, 0.16493624
        ], [
            0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462,
            0.0094893, 0.06890021
        ], [
            0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535,
            0.08424043, 0.08120984
        ], [
            0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305,
            0.05206269, 0.09772094
        ], [
            0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985,
            0.41317442, 0.01946335
        ], [
            0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937,
            0.04377724, 0.01457421
        ]]
        self.probs_seq2 = [[
            0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441,
            0.04468023, 0.10903471
        ], [
            0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123,
            0.10219457, 0.20640612
        ], [
            0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316,
            0.12298585, 0.01654384
        ], [
            0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055,
            0.22538587, 0.13483174
        ], [
            0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313,
            0.07113197, 0.04139363
        ], [
            0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306,
            0.05294827, 0.22298418
        ]]
        self.greedy_result = ["ac'bdc", "b'da"]
        self.beam_search_result = ['acdc', "b'a"]

    def test_greedy_decoder_1(self):
        bst_result = decoder.ctc_greedy_decoder(self.probs_seq1,
                                                self.vocab_list)
        self.assertEqual(bst_result, self.greedy_result[0])

    def test_greedy_decoder_2(self):
        bst_result = decoder.ctc_greedy_decoder(self.probs_seq2,
                                                self.vocab_list)
        self.assertEqual(bst_result, self.greedy_result[1])

    def test_beam_search_decoder_1(self):
        beam_result = decoder.ctc_beam_search_decoder(
            probs_seq=self.probs_seq1,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list)
        self.assertEqual(beam_result[0][1], self.beam_search_result[0])

    def test_beam_search_decoder_2(self):
        beam_result = decoder.ctc_beam_search_decoder(
            probs_seq=self.probs_seq2,
            beam_size=self.beam_size,
            vocabulary=self.vocab_list)
        self.assertEqual(beam_result[0][1], self.beam_search_result[1])

    def test_beam_search_decoder_batch(self):
        beam_results = decoder.ctc_beam_search_decoder_batch(
            probs_split=[self.probs_seq1, self.probs_seq2],
            beam_size=self.beam_size,
            vocabulary=self.vocab_list,
            num_processes=24)
        self.assertEqual(beam_results[0][0][1], self.beam_search_result[0])
        self.assertEqual(beam_results[1][0][1], self.beam_search_result[1])


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,19 @@
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


this_dir = os.path.dirname(__file__)

# Add project path to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)
@ -0,0 +1,94 @@
"""Client-end for the ASR demo."""
from pynput import keyboard
import struct
import socket
import sys
import argparse
import pyaudio

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--host_ip",
    default="localhost",
    type=str,
    help="Server IP address. (default: %(default)s)")
parser.add_argument(
    "--host_port",
    default=8086,
    type=int,
    help="Server Port. (default: %(default)s)")
args = parser.parse_args()

is_recording = False
enable_trigger_record = True


def on_press(key):
    """On-press keyboard callback function."""
    global is_recording, enable_trigger_record
    if key == keyboard.Key.space:
        if (not is_recording) and enable_trigger_record:
            sys.stdout.write("Start Recording ... ")
            sys.stdout.flush()
            is_recording = True


def on_release(key):
    """On-release keyboard callback function."""
    global is_recording, enable_trigger_record
    if key == keyboard.Key.esc:
        return False
    elif key == keyboard.Key.space:
        if is_recording == True:
            is_recording = False


data_list = []


def callback(in_data, frame_count, time_info, status):
    """Audio recorder's stream callback function."""
    global data_list, is_recording, enable_trigger_record
    if is_recording:
        data_list.append(in_data)
        enable_trigger_record = False
    elif len(data_list) > 0:
        # Connect to server and send data
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect((args.host_ip, args.host_port))
        sent = ''.join(data_list)
        sock.sendall(struct.pack('>i', len(sent)) + sent)
        print('Speech[length=%d] Sent.' % len(sent))
        # Receive data from the server and shut down
        received = sock.recv(1024)
        print("Recognition Results: {}".format(received))
        sock.close()
        data_list = []
        enable_trigger_record = True
    return (in_data, pyaudio.paContinue)


def main():
    # prepare audio recorder
    p = pyaudio.PyAudio()
    stream = p.open(
        format=pyaudio.paInt32,
        channels=1,
        rate=16000,
        input=True,
        stream_callback=callback)
    stream.start_stream()

    # prepare keyboard listener
    with keyboard.Listener(
            on_press=on_press, on_release=on_release) as listener:
        listener.join()

    # close up
    stream.stop_stream()
    stream.close()
    p.terminate()


if __name__ == "__main__":
    main()
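# Wire format between this client and the demo server (a sketch; both ends
# live in this repo). The payload is length-prefixed with a 4-byte big-endian
# signed integer, so the server can recv() in 1024-byte chunks until the
# announced length is reached:
#
#   payload = ''.join(data_list)                     # raw PCM audio bytes
#   message = struct.pack('>i', len(payload)) + payload
#   # receiver side:
#   length = struct.unpack('>i', message[:4])[0]
#   audio = message[4:4 + length]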
@ -0,0 +1,215 @@
"""Server-end for the ASR demo."""
import os
import time
import random
import argparse
import functools
from time import gmtime, strftime
import SocketServer
import struct
import wave
import paddle.v2 as paddle
import _init_paths
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from data_utils.utility import read_manifest
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('host_port',        int,    8086,    "Server's IP port.")
add_arg('beam_size',        int,    500,     "Beam search width.")
add_arg('num_conv_layers',  int,    2,       "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,       "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,    "# of recurrent cells per layer.")
add_arg('alpha',            float,  2.5,     "Coef of LM for beam search.")
add_arg('beta',             float,  0.3,     "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  1.0,     "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,      "Cutoff number for pruning.")
add_arg('use_gru',          bool,   False,   "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,    "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,    "Share input-hidden weights across "
                                             "bi-directional RNNs. Not for GRU.")
add_arg('host_ip',          str,
        'localhost',
        "Server's IP address.")
add_arg('speech_save_dir',  str,
        'demo_cache',
        "Directory to save demo audios.")
add_arg('warmup_manifest',  str,
        'data/librispeech/manifest.test-clean',
        "Filepath of manifest to warm up.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/eng_vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path',  str,
        'lm/data/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()


class AsrTCPServer(SocketServer.TCPServer):
    """The ASR TCP Server."""

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        SocketServer.TCPServer.__init__(
            self, server_address, RequestHandlerClass, bind_and_activate=True)


class AsrRequestHandler(SocketServer.BaseRequestHandler):
    """The ASR request handler."""

    def handle(self):
        # receive data through TCP socket
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript.encode('utf-8'))

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(4)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test."""
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s" % (idx, sample['audio_filepath']))
        start_time = time.time()
        transcript = audio_process_handler(sample['audio_filepath'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


def start_server():
    """Start the ASR server"""
    # prepare data generator
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1,
        keep_transcription_text=True)
    # prepare ASR model
    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    if args.decoding_method == "ctc_beam_search":
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)

    # prepare ASR inference handler
    def file_to_transcript(filename):
        feature = data_generator.process_utterance(filename, "")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=[feature],
            feeding_dict=data_generator.feeding)

        if args.decoding_method == "ctc_greedy":
            result_transcript = ds2_model.decode_batch_greedy(
                probs_split=probs_split,
                vocab_list=vocab_list)
        else:
            result_transcript = ds2_model.decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=args.alpha,
                beam_beta=args.beta,
                beam_size=args.beam_size,
                cutoff_prob=args.cutoff_prob,
                cutoff_top_n=args.cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=1)
        return result_transcript[0]

    # warm up with utterances sampled from LibriSpeech
    print('-----------------------------------------------------------')
    print('Warming up ...')
    warm_up_test(
        audio_process_handler=file_to_transcript,
        manifest_path=args.warmup_manifest,
        num_test_cases=3)
    print('-----------------------------------------------------------')

    # start the server
    server = AsrTCPServer(
        server_address=(args.host_ip, args.host_port),
        RequestHandlerClass=AsrRequestHandler,
        speech_save_dir=args.speech_save_dir,
        audio_process_handler=file_to_transcript)
    print("ASR Server Started.")
    server.serve_forever()


def main():
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu, trainer_count=1)
    start_server()


if __name__ == "__main__":
    main()
After Width: | Height: | Size: 153 KiB |
After Width: | Height: | Size: 108 KiB |
@ -0,0 +1,42 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python data/aishell/aishell.py \
--manifest_prefix='data/aishell/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/Aishell'

if [ $? -ne 0 ]; then
    echo "Prepare Aishell failed. Terminated."
    exit 1
fi


# build vocabulary
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/aishell/vocab.txt' \
--manifest_paths 'data/aishell/manifest.train' 'data/aishell/manifest.dev'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/aishell/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/aishell/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "Aishell data preparation done."
exit 0
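tools/compute_mean_std.py above estimates a per-dimension feature normalizer from a random subset of training utterances. Conceptually it reduces to the following sketch; the (num_dims x num_frames) feature orientation is an assumption, and the function name is illustrative rather than the repo's exact API:

import numpy as np

def compute_mean_std(features, output_path):
    """Compute per-dimension mean/std over all frames and save them (sketch).

    features: list of 2-D arrays, one (num_dims x num_frames) array per
    sampled utterance.
    """
    stacked = np.hstack(features)              # concatenate frames of all samples
    mean = np.mean(stacked, axis=1)            # per-dimension mean
    std = np.std(stacked, axis=1)              # per-dimension stddev
    np.savez(output_path, mean=mean, std=std)  # later loaded by the normalizer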
@ -0,0 +1,46 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=300 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--infer_manifest='data/aishell/manifest.test' \
--mean_std_path='data/aishell/mean_std.npz' \
--vocab_path='data/aishell/vocab.txt' \
--model_path='checkpoints/aishell/params.latest.tar.gz' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
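The --alpha and --beta flags weight the external language model and the word count in beam search decoding. Schematically, a candidate transcription is ranked by log P_acoustic + alpha * log P_lm + beta * word_count; a hedged sketch of that combined score (the exact form used by the decoder may differ in detail):

import math

def beam_score(log_p_acoustic, lm_prob, word_count, alpha, beta):
    """Combined score used to rank beam candidates (schematic form)."""
    return log_p_acoustic + alpha * math.log(lm_prob) + beta * word_count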
@ -0,0 +1,55 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/aishell > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=300 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--infer_manifest='data/aishell/manifest.test' \
--mean_std_path='models/aishell/mean_std.npz' \
--vocab_path='models/aishell/vocab.txt' \
--model_path='models/aishell/params.tar.gz' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
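--cutoff_prob and --cutoff_top_n prune the per-frame candidate set before each beam extension: only the most probable characters are kept, stopping once their cumulative probability reaches cutoff_prob or cutoff_top_n characters have been taken. A sketch of that pruning step (the real logic lives in the SWIG decoder):

import numpy as np

def prune_candidates(prob_vector, cutoff_prob=0.99, cutoff_top_n=40):
    """Return indices of the characters kept for beam extension (sketch)."""
    order = np.argsort(prob_vector)[::-1]  # most probable characters first
    kept, cumulative = [], 0.0
    for idx in order[:cutoff_top_n]:
        kept.append(idx)
        cumulative += prob_vector[idx]
        if cumulative >= cutoff_prob:      # enough probability mass covered
            break
    return kept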
@ -0,0 +1,47 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=300 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--test_manifest='data/aishell/manifest.test' \
--mean_std_path='data/aishell/mean_std.npz' \
--vocab_path='data/aishell/vocab.txt' \
--model_path='checkpoints/aishell/params.latest.tar.gz' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,56 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/aishell > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=300 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--test_manifest='data/aishell/manifest.test' \
--mean_std_path='models/aishell/mean_std.npz' \
--vocab_path='models/aishell/vocab.txt' \
--model_path='models/aishell/params.tar.gz' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,41 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# train model
# if you wish to resume from an existing model, uncomment --init_model_path
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u train.py \
--batch_size=64 \
--trainer_count=8 \
--num_passes=50 \
--num_proc_data=16 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--test_off=False \
--use_sortagrad=True \
--use_gru=True \
--use_gpu=True \
--is_local=True \
--share_rnn_weights=False \
--train_manifest='data/aishell/manifest.train' \
--dev_manifest='data/aishell/manifest.dev' \
--mean_std_path='data/aishell/mean_std.npz' \
--vocab_path='data/aishell/vocab.txt' \
--output_model_dir='./checkpoints/aishell' \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped'

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi


exit 0
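--use_sortagrad enables the SortaGrad curriculum from the Deep Speech 2 paper: in the first epoch, utterances are presented roughly shortest-first, which stabilizes early CTC training; later epochs shuffle normally. A sketch of the idea, assuming manifest entries carry a 'duration' field (the repo's reader implements this inside its shuffle methods):

import random

def order_epoch(manifest_entries, epoch, use_sortagrad=True):
    """Sort the first epoch by utterance duration, shuffle the rest (sketch)."""
    if use_sortagrad and epoch == 0:
        return sorted(manifest_entries, key=lambda e: e["duration"])
    shuffled = list(manifest_entries)
    random.shuffle(shuffled)
    return shuffled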
@ -0,0 +1,55 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/baidu_en8k > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=500 \
--num_proc_bsearch=5 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.4 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--infer_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/baidu_en8k/mean_std.npz' \
--vocab_path='models/baidu_en8k/vocab.txt' \
--model_path='models/baidu_en8k/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@ -0,0 +1,55 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/baidu_en8k > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -u test.py \
--batch_size=128 \
--trainer_count=4 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.4 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/baidu_en8k/mean_std.npz' \
--vocab_path='models/baidu_en8k/vocab.txt' \
--model_path='models/baidu_en8k/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi

exit 0
@ -0,0 +1,17 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# start demo client
CUDA_VISIBLE_DEVICES=0 \
python -u deploy/demo_client.py \
--host_ip='localhost' \
--host_port=8086

if [ $? -ne 0 ]; then
    echo "Failed in starting demo client!"
    exit 1
fi


exit 0
@ -0,0 +1,54 @@
#! /usr/bin/env bash
# TODO: replace the model with a mandarin model

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/baidu_en8k > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# start demo server
CUDA_VISIBLE_DEVICES=0 \
python -u deploy/demo_server.py \
--host_ip='localhost' \
--host_port=8086 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.15 \
--beta=0.15 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--speech_save_dir='demo_cache' \
--warmup_manifest='data/tiny/manifest.test-clean' \
--mean_std_path='models/baidu_en8k/mean_std.npz' \
--vocab_path='models/baidu_en8k/vocab.txt' \
--model_path='models/baidu_en8k/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in starting demo server!"
    exit 1
fi


exit 0
@ -0,0 +1,45 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
--manifest_prefix='data/librispeech/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/libri' \
--full_download='True'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

cat data/librispeech/manifest.train-* | shuf > data/librispeech/manifest.train


# build vocabulary
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/librispeech/vocab.txt' \
--manifest_paths='data/librispeech/manifest.train'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/librispeech/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/librispeech/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "LibriSpeech data preparation done."
exit 0
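tools/build_vocab.py above collects the character set of the training transcripts; --count_threshold drops characters rarer than the threshold (0 keeps everything seen at least once). A minimal sketch, again assuming JSON-lines manifests with a 'text' field; the function shape is illustrative, not the tool's exact interface:

import json
from collections import Counter

def build_vocab(manifest_paths, vocab_path, count_threshold=0):
    """Write one vocabulary character per line (sketch)."""
    counter = Counter()
    for path in manifest_paths:
        with open(path) as f:
            for line in f:
                counter.update(json.loads(line)["text"])  # count characters
    with open(vocab_path, "w") as f:
        for char, count in sorted(counter.items()):
            if count > count_threshold:
                f.write(char + "\n")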
@ -0,0 +1,46 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--model_path='checkpoints/libri/params.latest.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@ -0,0 +1,55 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@ -0,0 +1,47 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--model_path='checkpoints/libri/params.latest.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,56 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,41 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# train model
# if you wish to resume from an existing model, uncomment --init_model_path
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u train.py \
--batch_size=160 \
--trainer_count=8 \
--num_passes=50 \
--num_proc_data=16 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=5e-4 \
--max_duration=27.0 \
--min_duration=0.0 \
--test_off=False \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=True \
--share_rnn_weights=True \
--train_manifest='data/librispeech/manifest.train' \
--dev_manifest='data/librispeech/manifest.dev-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='data/librispeech/vocab.txt' \
--output_model_dir='./checkpoints/libri' \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped'

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi


exit 0
@ -0,0 +1,41 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# grid-search for hyper-parameters in language model
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -u tools/tune.py \
--num_batches=-1 \
--batch_size=128 \
--trainer_count=4 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--tune_manifest='data/librispeech/manifest.dev-clean' \
--mean_std_path='data/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in tuning!"
    exit 1
fi


exit 0
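tools/tune.py grid-searches the (alpha, beta) pair: --num_alphas and --num_betas define an evenly spaced grid over [alpha_from, alpha_to] x [beta_from, beta_to], and each pair is scored by re-decoding the tune set. Whether the tool uses linspace exactly is an assumption, but the flags imply a grid of this shape:

import numpy as np

# 45 x 8 = 360 candidate (alpha, beta) settings, matching the flags above
alphas = np.linspace(1.0, 3.2, num=45)
betas = np.linspace(0.1, 0.45, num=8)
grid = [(a, b) for a in alphas for b in betas]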
@ -0,0 +1,51 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# prepare folder
if [ ! -e data/tiny ]; then
    mkdir data/tiny
fi


# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python data/librispeech/librispeech.py \
--manifest_prefix='data/tiny/manifest' \
--target_dir='~/.cache/paddle/dataset/speech/libri' \
--full_download='False'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

head -n 64 data/tiny/manifest.dev-clean > data/tiny/manifest.tiny


# build vocabulary
python tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/tiny/vocab.txt' \
--manifest_paths='data/tiny/manifest.dev-clean'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python tools/compute_mean_std.py \
--manifest_path='data/tiny/manifest.tiny' \
--num_samples=64 \
--specgram_type='linear' \
--output_path='data/tiny/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "Tiny data preparation done."
exit 0
@ -0,0 +1,46 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--model_path='checkpoints/tiny/params.pass-19.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@ -0,0 +1,55 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python -u infer.py \
--num_samples=10 \
--trainer_count=1 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest='data/tiny/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@ -0,0 +1,47 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=16 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--model_path='checkpoints/tiny/params.pass-19.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,56 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
sh download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
sh download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u test.py \
--batch_size=128 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_proc_data=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/tiny/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/params.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -0,0 +1,41 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# train model
# if you wish to resume from an existing model, uncomment --init_model_path
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -u train.py \
--batch_size=16 \
--trainer_count=4 \
--num_passes=20 \
--num_proc_data=1 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_iter_print=100 \
--learning_rate=1e-5 \
--max_duration=27.0 \
--min_duration=0.0 \
--test_off=False \
--use_sortagrad=True \
--use_gru=False \
--use_gpu=True \
--is_local=True \
--share_rnn_weights=True \
--train_manifest='data/tiny/manifest.tiny' \
--dev_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--output_model_dir='./checkpoints/tiny' \
--augment_conf_path='conf/augmentation.config' \
--specgram_type='linear' \
--shuffle_method='batch_shuffle_clipped'

if [ $? -ne 0 ]; then
    echo "Failed in training!"
    exit 1
fi


exit 0
@ -0,0 +1,41 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# grid-search for hyper-parameters in language model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python -u tools/tune.py \
--num_batches=1 \
--batch_size=24 \
--trainer_count=8 \
--beam_size=500 \
--num_proc_bsearch=12 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--num_alphas=45 \
--num_betas=8 \
--alpha_from=1.0 \
--alpha_to=3.2 \
--beta_from=0.1 \
--beta_to=0.45 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--tune_manifest='data/tiny/manifest.tiny' \
--mean_std_path='data/tiny/mean_std.npz' \
--vocab_path='data/tiny/vocab.txt' \
--model_path='checkpoints/params.pass-9.tar.gz' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in tuning!"
    exit 1
fi


exit 0
@ -0,0 +1,135 @@
"""Inferer for DeepSpeech2 model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import functools
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model_utils.model import DeepSpeech2Model
from utils.error_rate import wer, cer
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples',      int,    10,     "# of samples to infer.")
add_arg('trainer_count',    int,    8,      "# of Trainers (CPUs or GPUs).")
add_arg('beam_size',        int,    500,    "Beam search width.")
add_arg('num_proc_bsearch', int,    8,      "# of CPUs for beam search.")
add_arg('num_conv_layers',  int,    2,      "# of convolution layers.")
add_arg('num_rnn_layers',   int,    3,      "# of recurrent layers.")
add_arg('rnn_layer_size',   int,    2048,   "# of recurrent cells per layer.")
add_arg('alpha',            float,  2.5,    "Coef of LM for beam search.")
add_arg('beta',             float,  0.3,    "Coef of WC for beam search.")
add_arg('cutoff_prob',      float,  1.0,    "Cutoff probability for pruning.")
add_arg('cutoff_top_n',     int,    40,     "Cutoff number for pruning.")
add_arg('use_gru',          bool,   False,  "Use GRUs instead of simple RNNs.")
add_arg('use_gpu',          bool,   True,   "Use GPU or not.")
add_arg('share_rnn_weights',bool,   True,   "Share input-hidden weights across "
                                            "bi-directional RNNs. Not for GRU.")
add_arg('infer_manifest',   str,
        'data/librispeech/manifest.dev-clean',
        "Filepath of manifest to infer.")
add_arg('mean_std_path',    str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path',       str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('lang_model_path',  str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('model_path',       str,
        './checkpoints/libri/params.latest.tar.gz',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('decoding_method',  str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type',  str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type',    str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()


def infer():
    """Inference for DeepSpeech2."""
    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        num_threads=1,
        keep_transcription_text=True)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.infer_manifest,
        batch_size=args.num_samples,
        min_batch_size=1,
        sortagrad=False,
        shuffle_method=None)
    infer_data = batch_reader().next()

    ds2_model = DeepSpeech2Model(
        vocab_size=data_generator.vocab_size,
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_layer_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        pretrained_model_path=args.model_path,
        share_rnn_weights=args.share_rnn_weights)

    # decoders only accept string encoded in utf-8
    vocab_list = [chars.encode("utf-8") for chars in data_generator.vocab_list]

    if args.decoding_method == "ctc_greedy":
        ds2_model.logger.info("start inference ...")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=infer_data,
            feeding_dict=data_generator.feeding)
        result_transcripts = ds2_model.decode_batch_greedy(
            probs_split=probs_split,
            vocab_list=vocab_list)
    else:
        ds2_model.init_ext_scorer(args.alpha, args.beta, args.lang_model_path,
                                  vocab_list)
        ds2_model.logger.info("start inference ...")
        probs_split = ds2_model.infer_batch_probs(
            infer_data=infer_data,
            feeding_dict=data_generator.feeding)
        result_transcripts = ds2_model.decode_batch_beam_search(
            probs_split=probs_split,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            cutoff_top_n=args.cutoff_top_n,
            vocab_list=vocab_list,
            num_processes=args.num_proc_bsearch)

    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    target_transcripts = [data[1] for data in infer_data]
    for target, result in zip(target_transcripts, result_transcripts):
        print("\nTarget Transcription: %s\nOutput Transcription: %s" %
              (target, result))
        print("Current error rate [%s] = %f" %
              (args.error_rate_type, error_rate_func(target, result)))

    ds2_model.logger.info("finish inference")


def main():
    print_arguments(args)
    paddle.init(use_gpu=args.use_gpu,
                rnn_use_batch=True,
                trainer_count=args.trainer_count)
    infer()


if __name__ == '__main__':
    main()
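The ctc_greedy_decoder used by decode_batch_greedy implements CTC best-path decoding: take the most probable token per frame, collapse consecutive repeats, then drop the blank. A self-contained sketch; the blank being the last index is an assumption here (the real decoder lives in decoders/swig_wrapper):

import numpy as np

def ctc_greedy_decode(probs_seq, vocab_list):
    """Best-path CTC decoding (sketch): argmax, collapse repeats, drop blank."""
    blank = len(vocab_list)                    # assumed blank index
    best_path = np.argmax(probs_seq, axis=1)   # most probable token per frame
    decoded, prev = [], None
    for idx in best_path:
        if idx != prev and idx != blank:       # collapse repeats, skip blank
            decoded.append(vocab_list[idx])
        prev = idx
    return "".join(decoded)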
@ -0,0 +1,442 @@
|
|||||||
|
"""Contains DeepSpeech2 model."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import gzip
|
||||||
|
import copy
|
||||||
|
import inspect
|
||||||
|
from distutils.dir_util import mkpath
|
||||||
|
import paddle.v2 as paddle
|
||||||
|
from decoders.swig_wrapper import Scorer
|
||||||
|
from decoders.swig_wrapper import ctc_greedy_decoder
|
||||||
|
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
||||||
|
from model_utils.network import deep_speech_v2_network
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSpeech2Model(object):
|
||||||
|
"""DeepSpeech2Model class.
|
||||||
|
|
||||||
|
:param vocab_size: Decoding vocabulary size.
|
||||||
|
:type vocab_size: int
|
||||||
|
:param num_conv_layers: Number of stacking convolution layers.
|
||||||
|
:type num_conv_layers: int
|
||||||
|
:param num_rnn_layers: Number of stacking RNN layers.
|
||||||
|
:type num_rnn_layers: int
|
||||||
|
:param rnn_layer_size: RNN layer size (number of RNN cells).
|
||||||
|
:type rnn_layer_size: int
|
||||||
|
:param pretrained_model_path: Pretrained model path. If None, will train
|
||||||
|
from stratch.
|
||||||
|
:type pretrained_model_path: basestring|None
|
||||||
|
:param share_rnn_weights: Whether to share input-hidden weights between
|
||||||
|
forward and backward directional RNNs.Notice that
|
||||||
|
for GRU, weight sharing is not supported.
|
||||||
|
:type share_rnn_weights: bool
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
|
||||||
|
rnn_layer_size, use_gru, pretrained_model_path,
|
||||||
|
share_rnn_weights):
|
||||||
|
self._create_network(vocab_size, num_conv_layers, num_rnn_layers,
|
||||||
|
rnn_layer_size, use_gru, share_rnn_weights)
|
||||||
|
self._create_parameters(pretrained_model_path)
|
||||||
|
self._inferer = None
|
||||||
|
self._loss_inferer = None
|
||||||
|
self._ext_scorer = None
|
||||||
|
self._num_conv_layers = num_conv_layers
|
||||||
|
self.logger = logging.getLogger("")
|
||||||
|
self.logger.setLevel(level=logging.INFO)
|
||||||
|
|
||||||
|
def train(self,
|
||||||
|
train_batch_reader,
|
||||||
|
dev_batch_reader,
|
||||||
|
feeding_dict,
|
||||||
|
learning_rate,
|
||||||
|
gradient_clipping,
|
||||||
|
num_passes,
|
||||||
|
output_model_dir,
|
||||||
|
is_local=True,
|
||||||
|
num_iterations_print=100,
|
||||||
|
test_off=False):
|
||||||
|
"""Train the model.
|
||||||
|
|
||||||
|
:param train_batch_reader: Train data reader.
|
||||||
|
:type train_batch_reader: callable
|
||||||
|
:param dev_batch_reader: Validation data reader.
|
||||||
|
:type dev_batch_reader: callable
|
||||||
|
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||||
|
of the data that reader returns.
|
||||||
|
:type feeding_dict: dict|list
|
||||||
|
:param learning_rate: Learning rate for ADAM optimizer.
|
||||||
|
:type learning_rate: float
|
||||||
|
:param gradient_clipping: Gradient clipping threshold.
|
||||||
|
:type gradient_clipping: float
|
||||||
|
:param num_passes: Number of training epochs.
|
||||||
|
:type num_passes: int
|
||||||
|
:param num_iterations_print: Number of training iterations for printing
|
||||||
|
a training loss.
|
||||||
|
:type rnn_iteratons_print: int
|
||||||
|
:param is_local: Set to False if running with pserver with multi-nodes.
|
||||||
|
:type is_local: bool
|
||||||
|
:param output_model_dir: Directory for saving the model (every pass).
|
||||||
|
:type output_model_dir: basestring
|
||||||
|
:param test_off: Turn off testing.
|
||||||
|
:type test_off: bool
|
||||||
|
"""
|
||||||
|
# prepare model output directory
|
||||||
|
if not os.path.exists(output_model_dir):
|
||||||
|
mkpath(output_model_dir)
|
||||||
|
|
||||||
|
# adapt the feeding dict and reader according to the network
|
||||||
|
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
|
||||||
|
adapted_train_batch_reader = self._adapt_data(train_batch_reader)
|
||||||
|
adapted_dev_batch_reader = self._adapt_data(dev_batch_reader)
|
||||||
|
|
||||||
|
# prepare optimizer and trainer
|
||||||
|
optimizer = paddle.optimizer.Adam(
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
gradient_clipping_threshold=gradient_clipping)
|
||||||
|
trainer = paddle.trainer.SGD(
|
||||||
|
cost=self._loss,
|
||||||
|
parameters=self._parameters,
|
||||||
|
update_equation=optimizer,
|
||||||
|
is_local=is_local)
|
||||||
|
|
||||||
|
# create event handler
|
||||||
|
def event_handler(event):
|
||||||
|
global start_time, cost_sum, cost_counter
|
||||||
|
if isinstance(event, paddle.event.EndIteration):
|
||||||
|
cost_sum += event.cost
|
||||||
|
cost_counter += 1
|
||||||
|
if (event.batch_id + 1) % num_iterations_print == 0:
|
||||||
|
output_model_path = os.path.join(output_model_dir,
|
||||||
|
"params.latest.tar.gz")
|
||||||
|
with gzip.open(output_model_path, 'w') as f:
|
||||||
|
trainer.save_parameter_to_tar(f)
|
||||||
|
print("\nPass: %d, Batch: %d, TrainCost: %f" %
|
||||||
|
(event.pass_id, event.batch_id + 1,
|
||||||
|
cost_sum / cost_counter))
|
||||||
|
cost_sum, cost_counter = 0.0, 0
|
||||||
|
else:
|
||||||
|
sys.stdout.write('.')
|
||||||
|
sys.stdout.flush()
|
||||||
|
if isinstance(event, paddle.event.BeginPass):
|
||||||
|
start_time = time.time()
|
||||||
|
cost_sum, cost_counter = 0.0, 0
|
||||||
|
if isinstance(event, paddle.event.EndPass):
|
||||||
|
if test_off:
|
||||||
|
print("\n------- Time: %d sec, Pass: %d" %
|
||||||
|
(time.time() - start_time, event.pass_id))
|
||||||
|
else:
|
||||||
|
result = trainer.test(
|
||||||
|
reader=adapted_dev_batch_reader,
|
||||||
|
feeding=adapted_feeding_dict)
|
||||||
|
print(
|
||||||
|
"\n------- Time: %d sec, Pass: %d, "
|
||||||
|
"ValidationCost: %s" %
|
||||||
|
(time.time() - start_time, event.pass_id, result.cost))
|
||||||
|
output_model_path = os.path.join(
|
||||||
|
output_model_dir, "params.pass-%d.tar.gz" % event.pass_id)
|
||||||
|
with gzip.open(output_model_path, 'w') as f:
|
||||||
|
trainer.save_parameter_to_tar(f)
|
||||||
|
|
||||||
|
# run train
|
||||||
|
trainer.train(
|
||||||
|
reader=adapted_train_batch_reader,
|
||||||
|
event_handler=event_handler,
|
||||||
|
num_passes=num_passes,
|
||||||
|
feeding=adapted_feeding_dict)
|
||||||
|
|
||||||
|
# TODO(@pkuyym) merge this function into infer_batch
|
||||||
|
def infer_loss_batch(self, infer_data):
|
||||||
|
"""Model inference. Infer the ctc loss for a batch of speech
|
||||||
|
utterances.
|
||||||
|
|
||||||
|
:param infer_data: List of utterances to infer, with each utterance a
|
||||||
|
tuple of audio features and transcription text (empty
|
||||||
|
string).
|
||||||
|
:type infer_data: list
|
||||||
|
:return: List of ctc loss.
|
||||||
|
:rtype: List of float
|
||||||
|
"""
|
||||||
|
# define inferer
|
||||||
|
if self._loss_inferer == None:
|
||||||
|
self._loss_inferer = paddle.inference.Inference(
|
||||||
|
output_layer=self._loss, parameters=self._parameters)
|
||||||
|
# run inference
|
||||||
|
return self._loss_inferer.infer(input=infer_data)
|
||||||
|
|
||||||
|
def infer_batch_probs(self, infer_data, feeding_dict):
|
||||||
|
"""Infer the prob matrices for a batch of speech utterances.
|
||||||
|
|
||||||
|
:param infer_data: List of utterances to infer, with each utterance
|
||||||
|
consisting of a tuple of audio features and
|
||||||
|
transcription text (empty string).
|
||||||
|
:type infer_data: list
|
||||||
|
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||||
|
of the data that reader returns.
|
||||||
|
:type feeding_dict: dict|list
|
||||||
|
:return: List of 2-D probability matrix, and each consists of prob
|
||||||
|
vectors for one speech utterancce.
|
||||||
|
:rtype: List of matrix
|
||||||
|
"""
|
||||||
|
# define inferer
|
||||||
|
if self._inferer == None:
|
||||||
|
self._inferer = paddle.inference.Inference(
|
||||||
|
output_layer=self._log_probs, parameters=self._parameters)
|
||||||
|
adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
|
||||||
|
adapted_infer_data = self._adapt_data(infer_data)
|
||||||
|
# run inference
|
||||||
|
infer_results = self._inferer.infer(
|
||||||
|
input=adapted_infer_data, feeding=adapted_feeding_dict)
|
||||||
|
start_pos = [0] * (len(adapted_infer_data) + 1)
|
||||||
|
for i in xrange(len(adapted_infer_data)):
|
||||||
|
start_pos[i + 1] = start_pos[i] + adapted_infer_data[i][3][0]
|
||||||
|
probs_split = [
|
||||||
|
infer_results[start_pos[i]:start_pos[i + 1]]
|
||||||
|
for i in xrange(0, len(adapted_infer_data))
|
||||||
|
]
|
||||||
|
return probs_split
|
||||||
|
|
||||||
|
def decode_batch_greedy(self, probs_split, vocab_list):
|
||||||
|
"""Decode by best path for a batch of probs matrix input.
|
||||||
|
|
||||||
|
:param probs_split: List of 2-D probability matrix, and each consists
|
||||||
|
of prob vectors for one speech utterancce.
|
||||||
|
:param probs_split: List of matrix
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
:return: List of transcription texts.
|
||||||
|
:rtype: List of basestring
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for i, probs in enumerate(probs_split):
|
||||||
|
output_transcription = ctc_greedy_decoder(
|
||||||
|
probs_seq=probs, vocabulary=vocab_list)
|
||||||
|
results.append(output_transcription)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
|
||||||
|
vocab_list):
|
||||||
|
"""Initialize the external scorer.
|
||||||
|
|
||||||
|
:param beam_alpha: Parameter associated with language model.
|
||||||
|
:type beam_alpha: float
|
||||||
|
:param beam_beta: Parameter associated with word count.
|
||||||
|
:type beam_beta: float
|
||||||
|
:param language_model_path: Filepath for language model. If it is
|
||||||
|
empty, the external scorer will be set to
|
||||||
|
None, and the decoding method will be pure
|
||||||
|
beam search without scorer.
|
||||||
|
:type language_model_path: basestring|None
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
"""
|
||||||
|
if language_model_path != '':
|
||||||
|
self.logger.info("begin to initialize the external scorer "
|
||||||
|
"for decoding")
|
||||||
|
self._ext_scorer = Scorer(beam_alpha, beam_beta,
|
||||||
|
language_model_path, vocab_list)
|
||||||
|
lm_char_based = self._ext_scorer.is_character_based()
|
||||||
|
lm_max_order = self._ext_scorer.get_max_order()
|
||||||
|
lm_dict_size = self._ext_scorer.get_dict_size()
|
||||||
|
self.logger.info("language model: "
|
||||||
|
"is_character_based = %d," % lm_char_based +
|
||||||
|
" max_order = %d," % lm_max_order +
|
||||||
|
" dict_size = %d" % lm_dict_size)
|
||||||
|
self.logger.info("end initializing scorer")
|
||||||
|
else:
|
||||||
|
self._ext_scorer = None
|
||||||
|
self.logger.info("no language model provided, "
|
||||||
|
"decoding by pure beam search without scorer.")
|
||||||
|
|
||||||
|
def decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
|
||||||
|
beam_size, cutoff_prob, cutoff_top_n,
|
||||||
|
vocab_list, num_processes):
|
||||||
|
"""Decode by beam search for a batch of probs matrix input.
|
||||||
|
|
||||||
|
:param probs_split: List of 2-D probability matrix, and each consists
|
||||||
|
of prob vectors for one speech utterancce.
|
||||||
|
:param probs_split: List of matrix
|
||||||
|
:param beam_alpha: Parameter associated with language model.
|
||||||
|
:type beam_alpha: float
|
||||||
|
:param beam_beta: Parameter associated with word count.
|
||||||
|
:type beam_beta: float
|
||||||
|
:param beam_size: Width for Beam search.
|
||||||
|
:type beam_size: int
|
||||||
|
:param cutoff_prob: Cutoff probability in pruning,
|
||||||
|
default 1.0, no pruning.
|
||||||
|
:type cutoff_prob: float
|
||||||
|
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
|
||||||
|
characters with highest probs in vocabulary will be
|
||||||
|
used in beam search, default 40.
|
||||||
|
:type cutoff_top_n: int
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
:param num_processes: Number of processes (CPU) for decoder.
|
||||||
|
:type num_processes: int
|
||||||
|
:return: List of transcription texts.
|
||||||
|
:rtype: List of basestring
|
||||||
|
"""
|
||||||
|
if self._ext_scorer != None:
|
||||||
|
self._ext_scorer.reset_params(beam_alpha, beam_beta)
|
||||||
|
# beam search decode
|
||||||
|
num_processes = min(num_processes, len(probs_split))
|
||||||
|
beam_search_results = ctc_beam_search_decoder_batch(
|
||||||
|
probs_split=probs_split,
|
||||||
|
vocabulary=vocab_list,
|
||||||
|
beam_size=beam_size,
|
||||||
|
num_processes=num_processes,
|
||||||
|
ext_scoring_func=self._ext_scorer,
|
||||||
|
cutoff_prob=cutoff_prob,
|
||||||
|
cutoff_top_n=cutoff_top_n)
|
||||||
|
|
||||||
|
results = [result[0][1] for result in beam_search_results]
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _adapt_feeding_dict(self, feeding_dict):
|
||||||
|
"""Adapt feeding dict according to network struct.
|
||||||
|
|
||||||
|
To remove impacts from padding part, we add scale_sub_region layer and
|
||||||
|
sub_seq layer. For sub_seq layer, 'sequence_offset' and
|
||||||
|
'sequence_length' fields are appended. For each scale_sub_region layer
|
||||||
|
'convN_index_range' field is appended.
|
||||||
|
|
||||||
|
:param feeding_dict: Feeding is a map of field name and tuple index
|
||||||
|
of the data that reader returns.
|
||||||
|
:type feeding_dict: dict|list
|
||||||
|
:return: Adapted feeding dict.
|
||||||
|
:rtype: dict|list
|
||||||
|
"""
|
||||||
|
adapted_feeding_dict = copy.deepcopy(feeding_dict)
|
||||||
|
if isinstance(feeding_dict, dict):
|
||||||
|
adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict)
|
||||||
|
adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict)
|
||||||
|
for i in xrange(self._num_conv_layers):
|
||||||
|
adapted_feeding_dict["conv%d_index_range" %i] = \
|
||||||
|
len(adapted_feeding_dict)
|
||||||
|
elif isinstance(feeding_dict, list):
|
||||||
|
adapted_feeding_dict.append("sequence_offset")
|
||||||
|
adapted_feeding_dict.append("sequence_length")
|
||||||
|
for i in xrange(self._num_conv_layers):
|
||||||
|
adapted_feeding_dict.append("conv%d_index_range" % i)
|
||||||
|
else:
|
||||||
|
raise ValueError("Type of feeding_dict is %s, not supported." %
|
||||||
|
type(feeding_dict))
|
||||||
|
|
||||||
|
return adapted_feeding_dict
|
||||||
|
|
||||||
|
    def _adapt_data(self, data):
        """Adapt data to the network structure.

        For each convolution layer in the conv_group, to remove the impact of
        padding data, we multiply the padding part of the outputs of each
        batch normalization layer by zero. We add a scale_sub_region layer
        after each batch normalization layer to reset the padding data.
        For the rnn layers, to remove the impact of padding data, we truncate
        the padding part before the output data is fed into the first rnn
        layer. We use a sub_seq layer to achieve this.

        :param data: Data from data_provider.
        :type data: list|function
        :return: Adapted data.
        :rtype: list|function
        """

        def adapt_instance(instance):
            if len(instance) < 2 or len(instance) > 3:
                raise ValueError("Size of instance should be 2 or 3.")
            padded_audio = instance[0]
            text = instance[1]
            # no padding part
            if len(instance) == 2:
                audio_len = padded_audio.shape[1]
            else:
                audio_len = instance[2]
            adapted_instance = [padded_audio, text]
            # Stride size for conv0 is (3, 2)
            # Stride size for conv1 to convN is (1, 2)
            # Same as the network, hard-coded here
            padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
            padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
            valid_w = (audio_len - 1) // 3 + 1
            adapted_instance += [
                [0],  # sequence offset, always 0
                [valid_w],  # valid sequence length
                # Index ranges for channel, height and width
                # Please refer to the scale_sub_region layer for details
                [1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
            ]
            pre_padded_h = padded_conv0_h
            for i in xrange(self._num_conv_layers - 1):
                padded_h = (pre_padded_h - 1) // 2 + 1
                pre_padded_h = padded_h
                adapted_instance += [
                    [1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
                ]
            return adapted_instance

        if isinstance(data, list):
            return map(adapt_instance, data)
        elif inspect.isgeneratorfunction(data):

            def adapted_reader():
                for instance in data():
                    yield map(adapt_instance, instance)

            return adapted_reader
        else:
            raise ValueError("Type of data is %s, not supported." % type(data))

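To make the index-range arithmetic concrete, here is a worked example under the hard-coded strides above (the spectrogram shape is an assumption for illustration):

# A padded spectrogram of 161 frequency bins x 300 time steps, of which only
# the first 240 steps are valid (the rest is batch padding).
padded_h, padded_w, audio_len = 161, 300, 240

padded_conv0_h = (padded_h - 1) // 2 + 1  # (161 - 1) // 2 + 1 = 81
padded_conv0_w = (padded_w - 1) // 3 + 1  # (300 - 1) // 3 + 1 = 100
valid_w = (audio_len - 1) // 3 + 1        # (240 - 1) // 3 + 1 = 80

# The region to zero out after conv0 covers all 32 channels, all 81 output
# rows, and output columns 81..100 (1-based), i.e. everything to the right
# of the valid part:
conv0_index_range = [1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
# -> [1, 32, 1, 81, 81, 100]
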
    def _create_parameters(self, model_path=None):
        """Load or create model parameters."""
        if model_path is None:
            self._parameters = paddle.parameters.create(self._loss)
        else:
            self._parameters = paddle.parameters.Parameters.from_tar(
                gzip.open(model_path))

    def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
                        rnn_layer_size, use_gru, share_rnn_weights):
        """Create data layers and model network."""
        # paddle.data_type.dense_array is used for variable batch input.
        # The size 161 * 161 is only a placeholder value and the real shape
        # of the input batch data will be inferred during training.
        audio_data = paddle.layer.data(
            name="audio_spectrogram",
            type=paddle.data_type.dense_array(161 * 161))
        text_data = paddle.layer.data(
            name="transcript_text",
            type=paddle.data_type.integer_value_sequence(vocab_size))
        seq_offset_data = paddle.layer.data(
            name='sequence_offset',
            type=paddle.data_type.integer_value_sequence(1))
        seq_len_data = paddle.layer.data(
            name='sequence_length',
            type=paddle.data_type.integer_value_sequence(1))
        index_range_datas = []
        # one index-range data layer per convolution layer, matching
        # the fields appended by _adapt_feeding_dict
        for i in xrange(num_conv_layers):
            index_range_datas.append(
                paddle.layer.data(
                    name='conv%d_index_range' % i,
                    type=paddle.data_type.dense_vector(6)))

        self._log_probs, self._loss = deep_speech_v2_network(
            audio_data=audio_data,
            text_data=text_data,
            seq_offset_data=seq_offset_data,
            seq_len_data=seq_len_data,
            index_range_datas=index_range_datas,
            dict_size=vocab_size,
            num_conv_layers=num_conv_layers,
            num_rnn_layers=num_rnn_layers,
            rnn_size=rnn_layer_size,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)

@ -0,0 +1,302 @@
"""Contains DeepSpeech2 layers and networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.v2 as paddle


def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
                  padding, act, index_range_data):
    """Convolution layer with batch normalization.

    :param input: Input layer.
    :type input: LayerOutput
    :param filter_size: The x dimension of a filter kernel. Or input a tuple
                        for two image dimensions.
    :type filter_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: The x dimension of the stride. Or input a tuple for two
                   image dimensions.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimensions.
    :type padding: int|tuple|list
    :param act: Activation type.
    :type act: BaseActivation
    :param index_range_data: Index range to indicate the sub region.
    :type index_range_data: LayerOutput
    :return: Batch norm layer after the convolution layer.
    :rtype: LayerOutput
    """
    conv_layer = paddle.layer.img_conv(
        input=input,
        filter_size=filter_size,
        num_channels=num_channels_in,
        num_filters=num_channels_out,
        stride=stride,
        padding=padding,
        act=paddle.activation.Linear(),
        bias_attr=False)
    batch_norm = paddle.layer.batch_norm(input=conv_layer, act=act)
    # reset padding part to 0
    scale_sub_region = paddle.layer.scale_sub_region(
        batch_norm, index_range_data, value=0.0)
    return scale_sub_region

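A minimal usage sketch of conv_bn_layer, mirroring the first layer of conv_group below; the two input layers are assumed to be declared elsewhere:

# Sketch: `audio_data` and `conv0_index_range_data` are assumed LayerOutputs
# (see the model's _create_network above for how such layers are declared).
conv0 = conv_bn_layer(
    input=audio_data,
    filter_size=(11, 41),
    num_channels_in=1,
    num_channels_out=32,
    stride=(3, 2),
    padding=(5, 20),
    act=paddle.activation.BRelu(),
    index_range_data=conv0_index_range_data)
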
def bidirectional_simple_rnn_bn_layer(name, input, size, act, share_weights):
    """Bidirectional simple rnn layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type.
    :type act: BaseActivation
    :param share_weights: Whether to share input-hidden weights between
                          the forward and backward directional RNNs.
    :type share_weights: bool
    :return: Bidirectional simple rnn layer.
    :rtype: LayerOutput
    """
    if share_weights:
        # input-hidden weights are shared between the bi-directional rnns.
        input_proj = paddle.layer.fc(
            input=input,
            size=size,
            act=paddle.activation.Linear(),
            bias_attr=False)
        # batch norm is only performed on the input-state projection
        input_proj_bn = paddle.layer.batch_norm(
            input=input_proj, act=paddle.activation.Linear())
        # forward and backward in time
        forward_simple_rnn = paddle.layer.recurrent(
            input=input_proj_bn, act=act, reverse=False)
        backward_simple_rnn = paddle.layer.recurrent(
            input=input_proj_bn, act=act, reverse=True)

    else:
        input_proj_forward = paddle.layer.fc(
            input=input,
            size=size,
            act=paddle.activation.Linear(),
            bias_attr=False)
        input_proj_backward = paddle.layer.fc(
            input=input,
            size=size,
            act=paddle.activation.Linear(),
            bias_attr=False)
        # batch norm is only performed on the input-state projections
        input_proj_bn_forward = paddle.layer.batch_norm(
            input=input_proj_forward, act=paddle.activation.Linear())
        input_proj_bn_backward = paddle.layer.batch_norm(
            input=input_proj_backward, act=paddle.activation.Linear())
        # forward and backward in time
        forward_simple_rnn = paddle.layer.recurrent(
            input=input_proj_bn_forward, act=act, reverse=False)
        backward_simple_rnn = paddle.layer.recurrent(
            input=input_proj_bn_backward, act=act, reverse=True)

    return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn])

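The share_weights switch trades parameters for capacity: with sharing, a single input projection and a single batch norm feed both directions, roughly halving the input-hidden parameter count. A sketch of both modes over an assumed upstream layer:

# Sketch: `prev_layer` is an assumed upstream LayerOutput.
shared = bidirectional_simple_rnn_bn_layer(
    name="rnn_shared", input=prev_layer, size=1024,
    act=paddle.activation.BRelu(), share_weights=True)
unshared = bidirectional_simple_rnn_bn_layer(
    name="rnn_unshared", input=prev_layer, size=1024,
    act=paddle.activation.BRelu(), share_weights=False)
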
def bidirectional_gru_bn_layer(name, input, size, act):
    """Bidirectional GRU layer with sequence-wise batch normalization.
    The batch normalization is only performed on input-state weights.

    :param name: Name of the layer.
    :type name: string
    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells.
    :type size: int
    :param act: Activation type.
    :type act: BaseActivation
    :return: Bidirectional GRU layer.
    :rtype: LayerOutput
    """
    input_proj_forward = paddle.layer.fc(
        input=input,
        size=size * 3,
        act=paddle.activation.Linear(),
        bias_attr=False)
    input_proj_backward = paddle.layer.fc(
        input=input,
        size=size * 3,
        act=paddle.activation.Linear(),
        bias_attr=False)
    # batch norm is only performed on input-related projections
    input_proj_bn_forward = paddle.layer.batch_norm(
        input=input_proj_forward, act=paddle.activation.Linear())
    input_proj_bn_backward = paddle.layer.batch_norm(
        input=input_proj_backward, act=paddle.activation.Linear())
    # forward and backward in time
    forward_gru = paddle.layer.grumemory(
        input=input_proj_bn_forward, act=act, reverse=False)
    backward_gru = paddle.layer.grumemory(
        input=input_proj_bn_backward, act=act, reverse=True)
    return paddle.layer.concat(input=[forward_gru, backward_gru])

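Note the size * 3 in the projections: paddle.layer.grumemory consumes the update-gate, reset-gate, and candidate projections concatenated, so the projected input must be three times the cell size. A usage sketch:

# Sketch: `prev_layer` is an assumed upstream LayerOutput of any width.
gru_out = bidirectional_gru_bn_layer(
    name="gru0", input=prev_layer, size=1024,
    act=paddle.activation.Relu())
# gru_out concatenates the forward and backward outputs: width 2 * 1024.
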
def conv_group(input, num_stacks, index_range_datas):
    """Convolution group with stacked convolution layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    :param index_range_datas: Index ranges for each convolution layer.
    :type index_range_datas: tuple|list
    :return: Output layer of the convolution group, the number of output
             channels, and the output height.
    :rtype: tuple of LayerOutput, int, int
    """
    conv = conv_bn_layer(
        input=input,
        filter_size=(11, 41),
        num_channels_in=1,
        num_channels_out=32,
        stride=(3, 2),
        padding=(5, 20),
        act=paddle.activation.BRelu(),
        index_range_data=index_range_datas[0])
    for i in xrange(num_stacks - 1):
        conv = conv_bn_layer(
            input=conv,
            filter_size=(11, 21),
            num_channels_in=32,
            num_channels_out=32,
            stride=(1, 2),
            padding=(5, 10),
            act=paddle.activation.BRelu(),
            index_range_data=index_range_datas[i + 1])
    output_num_channels = 32
    output_height = 160 // pow(2, num_stacks) + 1
    return conv, output_num_channels, output_height

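For the default two-layer stack and 161 frequency bins, the returned output_height follows from halving the frequency axis once per layer, which the closed form above reproduces:

# conv0 halves the height: (161 - 1) // 2 + 1 = 81
# conv1 halves it again:   (81 - 1) // 2 + 1  = 41
output_height = 160 // pow(2, 2) + 1  # = 41, matching the layer-by-layer count
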
def rnn_group(input, size, num_stacks, use_gru, share_rnn_weights):
    """RNN group with stacked bidirectional simple RNN layers.

    :param input: Input layer.
    :type input: LayerOutput
    :param size: Number of RNN cells in each layer.
    :type size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              the forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: Output layer of the RNN group.
    :rtype: LayerOutput
    """
    output = input
    for i in xrange(num_stacks):
        if use_gru:
            output = bidirectional_gru_bn_layer(
                name=str(i),
                input=output,
                size=size,
                act=paddle.activation.Relu())
            # BRelu does not support hppl, need to add later. Use Relu instead.
        else:
            output = bidirectional_simple_rnn_bn_layer(
                name=str(i),
                input=output,
                size=size,
                act=paddle.activation.BRelu(),
                share_weights=share_rnn_weights)
    return output

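A usage sketch of the group; each stacked layer consumes the 2 * size-wide concatenated output of the previous one:

# Sketch: `seq_input` is an assumed upstream LayerOutput (a sequence).
rnn_out = rnn_group(
    input=seq_input,
    size=1024,
    num_stacks=3,
    use_gru=False,
    share_rnn_weights=True)
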
def deep_speech_v2_network(audio_data,
                           text_data,
                           seq_offset_data,
                           seq_len_data,
                           index_range_datas,
                           dict_size,
                           num_conv_layers=2,
                           num_rnn_layers=3,
                           rnn_size=256,
                           use_gru=False,
                           share_rnn_weights=True):
    """The DeepSpeech2 network structure.

    :param audio_data: Audio spectrogram data layer.
    :type audio_data: LayerOutput
    :param text_data: Transcription text data layer.
    :type text_data: LayerOutput
    :param seq_offset_data: Sequence offset data layer.
    :type seq_offset_data: LayerOutput
    :param seq_len_data: Valid sequence length data layer.
    :type seq_len_data: LayerOutput
    :param index_range_datas: Index ranges data layers.
    :type index_range_datas: tuple|list
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacked convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacked RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (number of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              the forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    :return: A tuple of an output unnormalized log probability layer (
             before softmax) and a ctc cost layer.
    :rtype: tuple of LayerOutput
    """
    # convolution group
    conv_group_output, conv_group_num_channels, conv_group_height = conv_group(
        input=audio_data,
        num_stacks=num_conv_layers,
        index_range_datas=index_range_datas)
    # convert data from convolution feature map to sequence of vectors
    conv2seq = paddle.layer.block_expand(
        input=conv_group_output,
        num_channels=conv_group_num_channels,
        stride_x=1,
        stride_y=1,
        block_x=1,
        block_y=conv_group_height)
    # remove padding part
    remove_padding_data = paddle.layer.sub_seq(
        input=conv2seq,
        offsets=seq_offset_data,
        sizes=seq_len_data,
        act=paddle.activation.Linear(),
        bias_attr=False)
    # rnn group
    rnn_group_output = rnn_group(
        input=remove_padding_data,
        size=rnn_size,
        num_stacks=num_rnn_layers,
        use_gru=use_gru,
        share_rnn_weights=share_rnn_weights)
    fc = paddle.layer.fc(
        input=rnn_group_output,
        size=dict_size + 1,
        act=paddle.activation.Linear(),
        bias_attr=True)
    # probability distribution with softmax
    log_probs = paddle.layer.mixed(
        input=paddle.layer.identity_projection(input=fc),
        act=paddle.activation.Softmax())
    # ctc cost
    ctc_loss = paddle.layer.warp_ctc(
        input=fc,
        label=text_data,
        size=dict_size + 1,
        blank=dict_size,
        norm_by_times=True)
    return log_probs, ctc_loss

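Putting the pieces together, a minimal sketch of declaring the data layers and wiring them into the network, mirroring what the model's _create_network does; vocab_size is an assumed value:

import paddle.v2 as paddle

vocab_size = 28  # assumed for illustration, e.g. English characters

audio_data = paddle.layer.data(
    name="audio_spectrogram",
    type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
    name="transcript_text",
    type=paddle.data_type.integer_value_sequence(vocab_size))
seq_offset_data = paddle.layer.data(
    name='sequence_offset',
    type=paddle.data_type.integer_value_sequence(1))
seq_len_data = paddle.layer.data(
    name='sequence_length',
    type=paddle.data_type.integer_value_sequence(1))
index_range_datas = [
    paddle.layer.data(
        name='conv%d_index_range' % i,
        type=paddle.data_type.dense_vector(6)) for i in xrange(2)
]

log_probs, ctc_loss = deep_speech_v2_network(
    audio_data=audio_data,
    text_data=text_data,
    seq_offset_data=seq_offset_data,
    seq_len_data=seq_len_data,
    index_range_datas=index_range_datas,
    dict_size=vocab_size,
    num_conv_layers=2,
    num_rnn_layers=3,
    rnn_size=256,
    use_gru=False,
    share_rnn_weights=True)
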
@ -0,0 +1,19 @@
#!/usr/bin/env bash

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model.tar.gz'
MD5=0ee83aa15fba421e5de8fc66c8feb350
TARGET=./aishell_model.tar.gz


echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download Aishell model!"
    exit 1
fi
tar -zxvf $TARGET


exit 0
@ -0,0 +1,19 @@
#!/usr/bin/env bash

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/demo_models/baidu_en8k_model.tar.gz'
MD5=5fe7639e720d51b3c3bdf7a1470c6272
TARGET=./baidu_en8k_model.tar.gz


echo "Download BaiduEn8k model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download BaiduEn8k model!"
    exit 1
fi
tar -zxvf $TARGET


exit 0
@ -0,0 +1,19 @@
#!/usr/bin/env bash

. ../../utils/utility.sh

URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_model.tar.gz'
MD5=1f72d0c5591f453362f0caa09dd57618
TARGET=./librispeech_model.tar.gz


echo "Download LibriSpeech model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
    echo "Failed to download LibriSpeech model!"
    exit 1
fi
tar -zxvf $TARGET


exit 0