parent 4c7fefd4e3
commit f0184352f5
@@ -1,123 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Aishell Mandarin dataset.

Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os

import soundfile
from data_utils.utility import download
from data_utils.utility import unpack

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

URL_ROOT = 'http://www.openslr.org/resources/33'
# mirror of the same OpenSLR resource
URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/Aishell",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()


def create_manifest(data_dir, manifest_path_prefix):
    print("Creating manifest %s ..." % manifest_path_prefix)
    json_lines = []
    transcript_path = os.path.join(data_dir, 'transcript',
                                   'aishell_transcript_v0.8.txt')
    transcript_dict = {}
    for line in codecs.open(transcript_path, 'r', 'utf-8'):
        line = line.strip()
        if line == '':
            continue
        audio_id, text = line.split(' ', 1)
        # remove whitespace
        text = ''.join(text.split())
        transcript_dict[audio_id] = text

    data_types = ['train', 'dev', 'test']
    for type in data_types:
        del json_lines[:]
        audio_dir = os.path.join(data_dir, 'wav', type)
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for fname in filelist:
                audio_path = os.path.join(subfolder, fname)
                audio_id = fname[:-4]
                # skip audio files that have no transcription
                if audio_id not in transcript_dict:
                    continue
                audio_data, samplerate = soundfile.read(audio_path)
                duration = float(len(audio_data)) / samplerate
                text = transcript_dict[audio_id]
                json_lines.append(
                    json.dumps(
                        {
                            'audio_filepath': audio_path,
                            'duration': duration,
                            'text': text
                        },
                        ensure_ascii=False))
        manifest_path = manifest_path_prefix + '.' + type
        with codecs.open(manifest_path, 'w', 'utf-8') as fout:
            for line in json_lines:
                fout.write(line + '\n')


def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create manifest file."""
    data_dir = os.path.join(target_dir, 'data_aishell')
    if not os.path.exists(data_dir):
        filepath = download(url, md5sum, target_dir)
        unpack(filepath, target_dir)
        # unpack all audio tar files
        audio_dir = os.path.join(data_dir, 'wav')
        for subfolder, _, filelist in sorted(os.walk(audio_dir)):
            for ftar in filelist:
                unpack(os.path.join(subfolder, ftar), subfolder, True)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    create_manifest(data_dir, manifest_path)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(
        url=DATA_URL,
        md5sum=MD5_DATA,
        target_dir=args.target_dir,
        manifest_path=args.manifest_prefix)


if __name__ == '__main__':
    main()
@@ -1,159 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Librispeech ASR datasets.

Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import distutils.util
import io
import json
import os

import soundfile
from data_utils.utility import download
from data_utils.utility import unpack

URL_ROOT = "http://www.openslr.org/resources/12"
# mirror of the same OpenSLR resource
URL_ROOT = "https://openslr.magicdatatech.com/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"
URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz"
URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz"
URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"

MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default='~/.cache/paddle/dataset/speech/libri',
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
    "--full_download",
    default="True",
    type=distutils.util.strtobool,
    help="Download all datasets for Librispeech."
    " If False, only download a minimal requirement (test-clean, dev-clean,"
    " train-clean-100). (default: %(default)s)")
args = parser.parse_args()


def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each line
    containing the meta data (i.e. audio filepath, transcription text, audio
    duration) of each audio file within the data set.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        text_filelist = [
            filename for filename in filelist if filename.endswith('trans.txt')
        ]
        if len(text_filelist) > 0:
            text_filepath = os.path.join(subfolder, text_filelist[0])
            for line in io.open(text_filepath, encoding="utf8"):
                segments = line.strip().split()
                text = ' '.join(segments[1:]).lower()
                audio_filepath = os.path.join(subfolder, segments[0] + '.flac')
                audio_data, samplerate = soundfile.read(audio_filepath)
                duration = float(len(audio_data)) / samplerate
                json_lines.append(
                    json.dumps({
                        'audio_filepath': audio_filepath,
                        'duration': duration,
                        'text': text
                    }))
    with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')


def prepare_dataset(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
        # download
        filepath = download(url, md5sum, target_dir)
        # unpack
        unpack(filepath, target_dir)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(
        url=URL_TEST_CLEAN,
        md5sum=MD5_TEST_CLEAN,
        target_dir=os.path.join(args.target_dir, "test-clean"),
        manifest_path=args.manifest_prefix + ".test-clean")
    prepare_dataset(
        url=URL_DEV_CLEAN,
        md5sum=MD5_DEV_CLEAN,
        target_dir=os.path.join(args.target_dir, "dev-clean"),
        manifest_path=args.manifest_prefix + ".dev-clean")
    if args.full_download:
        prepare_dataset(
            url=URL_TRAIN_CLEAN_100,
            md5sum=MD5_TRAIN_CLEAN_100,
            target_dir=os.path.join(args.target_dir, "train-clean-100"),
            manifest_path=args.manifest_prefix + ".train-clean-100")
        prepare_dataset(
            url=URL_TEST_OTHER,
            md5sum=MD5_TEST_OTHER,
            target_dir=os.path.join(args.target_dir, "test-other"),
            manifest_path=args.manifest_prefix + ".test-other")
        prepare_dataset(
            url=URL_DEV_OTHER,
            md5sum=MD5_DEV_OTHER,
            target_dir=os.path.join(args.target_dir, "dev-other"),
            manifest_path=args.manifest_prefix + ".dev-other")
        prepare_dataset(
            url=URL_TRAIN_CLEAN_360,
            md5sum=MD5_TRAIN_CLEAN_360,
            target_dir=os.path.join(args.target_dir, "train-clean-360"),
            manifest_path=args.manifest_prefix + ".train-clean-360")
        prepare_dataset(
            url=URL_TRAIN_OTHER_500,
            md5sum=MD5_TRAIN_OTHER_500,
            target_dir=os.path.join(args.target_dir, "train-other-500"),
            manifest_path=args.manifest_prefix + ".train-other-500")


if __name__ == '__main__':
    main()
@@ -1,139 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare CHiME3 background data.

Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import io
import json
import os
import tarfile
import zipfile

import soundfile
import wget
from paddle.v2.dataset.common import md5file

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
MD5 = "c3ff512618d7a67d4f85566ea1bc39ec"

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/chime3_background",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--manifest_filepath",
    default="manifest.chime3.background",
    type=str,
    help="Filepath for output manifests. (default: %(default)s)")
args = parser.parse_args()


def download(url, md5sum, target_dir, filename=None):
    """Download file from url to target_dir, and check md5sum."""
    if filename is None:
        filename = url.split("/")[-1]
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, filename)
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        wget.download(url, target_dir)
        print("\nMD5 checksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath


def unpack(filepath, target_dir):
    """Unpack the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    if filepath.endswith('.zip'):
        zip = zipfile.ZipFile(filepath, 'r')
        zip.extractall(target_dir)
        zip.close()
    elif filepath.endswith('.tar') or filepath.endswith('.tar.gz'):
        tar = tarfile.open(filepath)
        tar.extractall(target_dir)
        tar.close()
    else:
        raise ValueError("File format is not supported for unpacking.")


def create_manifest(data_dir, manifest_path):
    """Create a manifest json file summarizing the data set, with each line
    containing the meta data (i.e. audio filepath, transcription text, audio
    duration) of each audio file within the data set.
    """
    print("Creating manifest %s ..." % manifest_path)
    json_lines = []
    for subfolder, _, filelist in sorted(os.walk(data_dir)):
        for filename in filelist:
            if filename.endswith('.wav'):
                filepath = os.path.join(subfolder, filename)
                audio_data, samplerate = soundfile.read(filepath)
                duration = float(len(audio_data)) / samplerate
                json_lines.append(
                    json.dumps({
                        'audio_filepath': filepath,
                        'duration': duration,
                        'text': ''
                    }))
    with io.open(manifest_path, mode='w', encoding='utf8') as out_file:
        for line in json_lines:
            out_file.write(line + '\n')


def prepare_chime3(url, md5sum, target_dir, manifest_path):
    """Download, unpack and create summary manifest file."""
    if not os.path.exists(os.path.join(target_dir, "CHiME3")):
        # download
        filepath = download(url, md5sum, target_dir,
                            "myairbridge-AG0Y3DNBE5IWRRTV.zip")
        # unpack
        unpack(filepath, target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_bus.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_caf.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_ped.zip'), target_dir)
        unpack(
            os.path.join(target_dir, 'CHiME3_background_str.zip'), target_dir)
    else:
        print("Skip downloading and unpacking. Data already exists in %s." %
              target_dir)
    # create manifest json file
    create_manifest(target_dir, manifest_path)


def main():
    prepare_chime3(
        url=URL,
        md5sum=MD5,
        target_dir=args.target_dir,
        manifest_path=args.manifest_filepath)


if __name__ == '__main__':
    main()
@@ -1,16 +0,0 @@
#! /usr/bin/env bash

# download data, generate manifests
PYTHONPATH=../../:$PYTHONPATH python voxforge.py \
--manifest_prefix='./manifest' \
--target_dir='./dataset/VoxForge' \
--is_merge_dialect=True \
--dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian'

if [ $? -ne 0 ]; then
    echo "Prepare VoxForge failed. Terminated."
    exit 1
fi

echo "VoxForge Data preparation done."
exit 0
@@ -1,234 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare VoxForge dataset.

Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import datetime
import distutils.util
import json
import os
import shutil
import subprocess

import soundfile
from data_utils.utility import download_multi
from data_utils.utility import getfile_insensitive
from data_utils.utility import unpack

DATA_HOME = './dataset'

DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \
           'Audio/Main/16kHz_16bit'

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--target_dir",
    default=DATA_HOME + "/VoxForge",
    type=str,
    help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
    "--dialects",
    default=[
        'american', 'british', 'australian', 'european', 'irish', 'canadian',
        'indian'
    ],
    nargs='+',
    type=str,
    help="Dialect types. (default: %(default)s)")
parser.add_argument(
    "--is_merge_dialect",
    default=True,
    type=distutils.util.strtobool,
    help="If set True, manifests of american dialect and canadian dialect will "
    "be merged to american-canadian dialect; manifests of british "
    "dialect, irish dialect and australian dialect will be merged to "
    "commonwealth dialect. (default: %(default)s)")
parser.add_argument(
    "--manifest_prefix",
    default="manifest",
    type=str,
    help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()


def download_and_unpack(target_dir, url):
    wget_args = '-q -l 1 -N -nd -c -e robots=off -A tgz -r -np'
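    # wget flags: quiet (-q), recursion depth 1 (-r -l 1), timestamping (-N),
    # no directory hierarchy (-nd), resume partial downloads (-c), ignore
    # robots.txt (-e robots=off), accept only .tgz files (-A tgz), and never
    # ascend to the parent directory (-np).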
    tgz_dir = os.path.join(target_dir, 'tgz')
    exit_code = download_multi(url, tgz_dir, wget_args)
    if exit_code != 0:
        print('Download tgz audio files failed with exit code %d.' % exit_code)
    else:
        print('Download done, start unpacking ...')
        audio_dir = os.path.join(target_dir, 'audio')
        for root, dirs, files in os.walk(tgz_dir):
            for file in files:
                print(file)
                if file.endswith('.tgz'):
                    unpack(os.path.join(root, file), audio_dir)


def select_dialects(target_dir, dialect_list):
    """Classify audio files by dialect."""
    dialect_root_dir = os.path.join(target_dir, 'dialect')
    if os.path.exists(dialect_root_dir):
        shutil.rmtree(dialect_root_dir)
    os.mkdir(dialect_root_dir)
    audio_dir = os.path.abspath(os.path.join(target_dir, 'audio'))
    for dialect in dialect_list:
        # filter files by dialect
        command = 'find %s -iwholename "*etc/readme*" -exec egrep -iHl \
                  "pronunciation dialect.*%s" {} \;' % (audio_dir, dialect)
        p = subprocess.Popen(
            command, stdin=subprocess.PIPE, stdout=subprocess.PIPE, shell=True)
        output, err = p.communicate()
        dialect_dir = os.path.join(dialect_root_dir, dialect)
        if os.path.exists(dialect_dir):
            shutil.rmtree(dialect_dir)
        os.mkdir(dialect_dir)
        for path in output.decode('utf-8').splitlines():
            src_dir = os.path.dirname(os.path.dirname(path))
            link = os.path.basename(os.path.normpath(src_dir))
            os.symlink(src_dir, os.path.join(dialect_dir, link))


def generate_manifest(data_dir, manifest_path):
    json_lines = []

    for path in os.listdir(data_dir):
        audio_link = os.path.join(data_dir, path)
        assert os.path.islink(
            audio_link), '%s should be symbolic link.' % audio_link
        actual_audio_dir = os.path.abspath(os.readlink(audio_link))

        audio_type = ''
        if os.path.isdir(os.path.join(actual_audio_dir, 'wav')):
            audio_type = 'wav'
        elif os.path.isdir(os.path.join(actual_audio_dir, 'flac')):
            audio_type = 'flac'
        else:
            print('Unknown audio type, skipped processing %s.' %
                  actual_audio_dir)
            continue

        etc_dir = os.path.join(actual_audio_dir, 'etc')
        prompts_file = os.path.join(etc_dir, 'PROMPTS')
        if not os.path.isfile(prompts_file):
            print('PROMPTS file missing, skip processing %s.' %
                  actual_audio_dir)
            continue

        readme_file = getfile_insensitive(os.path.join(etc_dir, 'README'))
        if readme_file is None:
            print('README file missing, skip processing %s.' % actual_audio_dir)
            continue

        for line in open(prompts_file):
            u, trans = line.strip().split(None, 1)
            u_parts = u.split('/')

            # try to format the date time
            try:
                speaker, date, sfx = u_parts[-3].split('-')
                obj = datetime.datetime.strptime(date, '%y.%m.%d')
                formatted = obj.strftime('%Y%m%d')
                u_parts[-3] = '-'.join([speaker, formatted, sfx])
            except Exception as e:
                pass

            if len(u_parts) < 2:
                u_parts = [audio_type] + u_parts
            u_parts[-2] = audio_type
            u_parts[-1] += '.' + audio_type
            u = os.path.join(actual_audio_dir, '/'.join(u_parts[-2:]))

            if not os.path.isfile(u):
                print('Audio file missing, skip processing %s.' % u)
                continue

            if os.stat(u).st_size == 0:
                print('Empty audio file, skip processing %s.' % u)
                continue

            trans = trans.strip().replace('-', ' ')
            if not trans.isupper() or \
                    not trans.strip().replace(' ', '').replace("'", "").isalpha():
                print("Transcript not normalized properly, skip processing %s."
                      % u)
                continue

            audio_data, samplerate = soundfile.read(u)
            duration = float(len(audio_data)) / samplerate
            json_lines.append(
                json.dumps({
                    'audio_filepath': u,
                    'duration': duration,
                    'text': trans.lower()
                }))

    with codecs.open(manifest_path, 'w', 'utf-8') as fout:
        for line in json_lines:
            fout.write(line + '\n')


def merge_manifests(manifest_files, save_path):
    lines = []
    for manifest_file in manifest_files:
        line = codecs.open(manifest_file, 'r', 'utf-8').readlines()
        lines += line

    with codecs.open(save_path, 'w', 'utf-8') as fout:
        for line in lines:
            fout.write(line)


def prepare_dataset(url, dialects, target_dir, manifest_prefix, is_merge):
    download_and_unpack(target_dir, url)
    select_dialects(target_dir, dialects)
    american_canadian_manifests = []
    commonwealth_manifests = []
    for dialect in dialects:
        dialect_dir = os.path.join(target_dir, 'dialect', dialect)
        manifest_fpath = manifest_prefix + '.' + dialect
        if dialect == 'american' or dialect == 'canadian':
            american_canadian_manifests.append(manifest_fpath)
        if dialect == 'australian' \
                or dialect == 'british' \
                or dialect == 'irish':
            commonwealth_manifests.append(manifest_fpath)
        generate_manifest(dialect_dir, manifest_fpath)

    if is_merge:
        if len(american_canadian_manifests) > 0:
            manifest_fpath = manifest_prefix + '.american-canadian'
            merge_manifests(american_canadian_manifests, manifest_fpath)
        if len(commonwealth_manifests) > 0:
            manifest_fpath = manifest_prefix + '.commonwealth'
            merge_manifests(commonwealth_manifests, manifest_fpath)


def main():
    if args.target_dir.startswith('~'):
        args.target_dir = os.path.expanduser(args.target_dir)

    prepare_dataset(DATA_URL, args.dialects, args.target_dir,
                    args.manifest_prefix, args.is_merge_dialect)


if __name__ == '__main__':
    main()
@@ -1,695 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio segment class."""
import copy
import io
import random
import re
import struct

import numpy as np
import resampy
import soundfile
from scipy import signal


class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate):
        """Create audio segment from samples.

        Samples are converted to float32 internally, with ints scaled to
        [-1, 1].
        """
        self._samples = self._convert_samples_to_float32(samples)
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            self._samples = np.mean(self._samples, 1)

    def __eq__(self, other):
        """Return whether two objects are equal."""
        if type(other) is not type(self):
            return False
        if self._sample_rate != other._sample_rate:
            return False
        if self._samples.shape != other._samples.shape:
            return False
        if np.any(self.samples != other._samples):
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    def __str__(self):
        """Return human-readable representation of segment."""
        return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
                "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
                                self.duration, self.rms_db))

    @classmethod
    def from_file(cls, file):
        """Create audio segment from audio file.

        :param file: Filepath or file object to audio file.
        :type file: str|file
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
            return cls.from_sequence_file(file)
        else:
            samples, sample_rate = soundfile.read(file, dtype='float32')
            return cls(samples, sample_rate)

    @classmethod
    def slice_from_file(cls, file, start=None, end=None):
        """Loads a small section of an audio file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param file: Input audio filepath or file object.
        :type file: str|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :return: AudioSegment instance of the specified slice of the input
                 audio file.
        :rtype: AudioSegment
        :raise ValueError: If start or end is incorrectly set, e.g. out of
                           bounds in time.
        """
        sndfile = soundfile.SoundFile(file)
        sample_rate = sndfile.samplerate
        duration = float(len(sndfile)) / sample_rate
        start = 0. if start is None else start
        end = duration if end is None else end
        if start < 0.0:
            start += duration
        if end < 0.0:
            end += duration
        if start < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start)
        if end < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end)
        if start > end:
            raise ValueError("The slice start position (%f s) is later than "
                             "the slice end position (%f s)." % (start, end))
        if end > duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end, duration))
        start_frame = int(start * sample_rate)
        end_frame = int(end * sample_rate)
        sndfile.seek(start_frame)
        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return cls(data, sample_rate)
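
    # e.g. AudioSegment.slice_from_file('long.wav', start=1.5, end=3.0) reads
    # only the frames between 1.5 s and 3.0 s from disk (path illustrative).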

    @classmethod
    def from_sequence_file(cls, filepath):
        """Create audio segment from sequence file. Sequence file is a binary
        file containing a collection of multiple audio files, with several
        header bytes in the head indicating the offsets of each audio byte data
        chunk.

        The format is:

            4 bytes (int, version),
            4 bytes (int, num of utterance),
            4 bytes (int, bytes per header),
            [bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
            audio_bytes_data_of_1st_utterance,
            audio_bytes_data_of_2nd_utterance,
            ......

        Sequence file name must end with ".seqbin". And the filename of the 5th
        utterance's audio file in sequence file "xxx.seqbin" must be
        "xxx.seqbin_5", with "5" indicating the utterance index within this
        sequence file (starting from 1).

        :param filepath: Filepath of sequence file.
        :type filepath: str
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        # parse filepath
        matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
        if matches is None:
            raise IOError("File type of %s is not supported" % filepath)
        filename = matches.group(1)
        fileno = int(matches.group(2))

        # read headers (binary mode, so no text encoding applies)
        f = io.open(filename, mode='rb')
        version = f.read(4)
        num_utterances = struct.unpack("i", f.read(4))[0]
        bytes_per_header = struct.unpack("i", f.read(4))[0]
        header_bytes = f.read(bytes_per_header * (num_utterances + 1))
        header = [
            struct.unpack("i", header_bytes[bytes_per_header * i:
                                            bytes_per_header * (i + 1)])[0]
            for i in range(num_utterances + 1)
        ]

        # read audio bytes
        f.seek(header[fileno - 1])
        audio_bytes = f.read(header[fileno] - header[fileno - 1])
        f.close()

        # create audio segment
        try:
            return cls.from_bytes(audio_bytes)
        except Exception as e:
            samples = np.frombuffer(audio_bytes, dtype='int16')
            return cls(samples=samples, sample_rate=8000)

    @classmethod
    def from_bytes(cls, bytes):
        """Create audio segment from a byte string containing audio samples.

        :param bytes: Byte string containing audio samples.
        :type bytes: bytes
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        samples, sample_rate = soundfile.read(
            io.BytesIO(bytes), dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of audio segments together.

        :param *segments: Input audio segments to be concatenated.
        :type *segments: tuple of AudioSegment
        :return: Audio segment instance as concatenating results.
        :rtype: AudioSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any segments does not match.
        :raises TypeError: If any segment is not AudioSegment instance.
        """
        # Perform basic sanity-checks.
        if len(segments) == 0:
            raise ValueError("No audio segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only audio segments of the same type "
                                "can be concatenated.")
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent audio segment of the given duration and sample rate.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silent AudioSegment instance of the given duration.
        :rtype: AudioSegment
        """
        samples = np.zeros(int(duration * sample_rate))
        return cls(samples, sample_rate)

    def to_wav_file(self, filepath, dtype='float32'):
        """Save audio segment to disk as wav file.

        :param filepath: WAV filepath or file object to save the
                         audio segment.
        :type filepath: str|file
        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :raises TypeError: If dtype is not supported.
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        subtype_map = {
            'int16': 'PCM_16',
            'int32': 'PCM_32',
            'float32': 'FLOAT',
            'float64': 'DOUBLE'
        }
        soundfile.write(
            filepath,
            samples,
            self._sample_rate,
            format='WAV',
            subtype=subtype_map[dtype])

    def superimpose(self, other):
        """Add samples from another segment to those of this segment
        (sample-wise addition, not segment concatenation).

        Note that this is an in-place transformation.

        :param other: Segment containing samples to be added in.
        :type other: AudioSegment
        :raise TypeError: If type of two segments don't match.
        :raise ValueError: If the sample rates of the two segments are not
                           equal, or if the lengths of segments don't match.
        """
        if not isinstance(other, type(self)):
            raise TypeError("Cannot add segments of different types: %s "
                            "and %s." % (type(self), type(other)))
        if self._sample_rate != other._sample_rate:
            raise ValueError("Sample rates must match to add segments.")
        if len(self._samples) != len(other._samples):
            raise ValueError("Segment lengths must match to add segments.")
        self._samples += other._samples

    def to_bytes(self, dtype='float32'):
        """Create a byte string containing the audio content.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :return: Byte string containing audio content.
        :rtype: bytes
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples.tobytes()

    def gain_db(self, gain):
        """Apply gain in decibels to samples.

        Note that this is an in-place transformation.

        :param gain: Gain in decibels to apply to samples.
        :type gain: float|1darray
        """
        self._samples *= 10.**(gain / 20.)

    def change_speed(self, speed_rate):
        """Change the audio speed by linear interpolation.

        Note that this is an in-place transformation.

        :param speed_rate: Rate of speed change:
                           speed_rate > 1.0, speed up the audio;
                           speed_rate = 1.0, unchanged;
                           speed_rate < 1.0, slow down the audio;
                           speed_rate <= 0.0, not allowed, raise ValueError.
        :type speed_rate: float
        :raises ValueError: If speed_rate <= 0.0.
        """
        if speed_rate <= 0:
            raise ValueError("speed_rate should be greater than zero.")
        old_length = self._samples.shape[0]
        new_length = int(old_length / speed_rate)
        old_indices = np.arange(old_length)
        new_indices = np.linspace(start=0, stop=old_length, num=new_length)
        self._samples = np.interp(new_indices, old_indices, self._samples)

    def normalize(self, target_db=-20, max_gain_db=300.0):
        """Normalize audio to be of the desired RMS value in decibels.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels. This value should be
                          less than 0.0 as 0.0 is full-scale audio.
        :type target_db: float
        :param max_gain_db: Max amount of gain in dB that can be applied for
                            normalization. This is to prevent nans when
                            attempting to normalize a signal consisting of
                            all zeros.
        :type max_gain_db: float
        :raises ValueError: If the required gain to normalize the segment to
                            the target_db value exceeds max_gain_db.
        """
        gain = target_db - self.rms_db
        if gain > max_gain_db:
            raise ValueError(
                "Unable to normalize segment to %f dB because the required "
                "gain exceeds max_gain_db (%f dB)" % (target_db, max_gain_db))
        self.gain_db(gain)

    def normalize_online_bayesian(self,
                                  target_db,
                                  prior_db,
                                  prior_samples,
                                  startup_delay=0.0):
        """Normalize audio using a production-compatible online/causal
        algorithm. This uses an exponential likelihood and gamma prior to
        make online estimates of the RMS even when there are very few samples.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels.
        :type target_db: float
        :param prior_db: Prior RMS estimate in decibels.
        :type prior_db: float
        :param prior_samples: Prior strength in number of samples.
        :type prior_samples: float
        :param startup_delay: Default 0.0s. If provided, this function will
                              accrue statistics for the first startup_delay
                              seconds before applying online normalization.
        :type startup_delay: float
        """
        # Estimate total RMS online.
        startup_sample_idx = min(self.num_samples - 1,
                                 int(self.sample_rate * startup_delay))
        prior_mean_squared = 10.**(prior_db / 10.)
        prior_sum_of_squares = prior_mean_squared * prior_samples
        cumsum_of_squares = np.cumsum(self.samples**2)
        sample_count = np.arange(self.num_samples) + 1
        if startup_sample_idx > 0:
            cumsum_of_squares[:startup_sample_idx] = \
                cumsum_of_squares[startup_sample_idx]
            sample_count[:startup_sample_idx] = \
                sample_count[startup_sample_idx]
        mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
                                 (sample_count + prior_samples))
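        # Posterior mean-square estimate at each sample: the observed
        # cumulative sum of squares plus the prior's pseudo sum of squares,
        # divided by the combined (observed + prior) sample count.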
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
        self.gain_db(gain_db)

    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use, one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(
            self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate

    def pad_silence(self, duration, sides='both'):
        """Pad this audio sample with a period of silence.

        Note that this is an in-place transformation.

        :param duration: Length of silence in seconds to pad.
        :type duration: float
        :param sides: Position for padding:
                      'beginning' - adds silence in the beginning;
                      'end' - adds silence in the end;
                      'both' - adds silence in both the beginning and the end.
        :type sides: str
        :raises ValueError: If sides is not supported.
        """
        if duration == 0.0:
            return self
        cls = type(self)
        silence = self.make_silence(duration, self._sample_rate)
        if sides == "beginning":
            padded = cls.concatenate(silence, self)
        elif sides == "end":
            padded = cls.concatenate(self, silence)
        elif sides == "both":
            padded = cls.concatenate(silence, self, silence)
        else:
            raise ValueError("Unknown value for the sides %s" % sides)
        self._samples = padded._samples

    def shift(self, shift_ms):
        """Shift the audio in time. If `shift_ms` is positive, shift with time
        advance; if negative, shift with time delay. Silence is padded to
        keep the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in milliseconds. If positive, shift with
                         time advance; if negative, shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("Absolute value of shift_ms should be smaller "
                             "than audio duration.")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
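        # e.g. shift_ms=500 at a 16000 Hz sample rate shifts by 8000 samples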
        if shift_samples > 0:
            # time advance
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0

    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.

        Note that this is an in-place transformation.

        :param start_sec: Beginning of subsegment in seconds.
        :type start_sec: float
        :param end_sec: End of subsegment in seconds.
        :type end_sec: float
        :raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
                           of bounds in time.
        """
        start_sec = 0.0 if start_sec is None else start_sec
        end_sec = self.duration if end_sec is None else end_sec
        if start_sec < 0.0:
            start_sec = self.duration + start_sec
        if end_sec < 0.0:
            end_sec = self.duration + end_sec
        if start_sec < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start_sec)
        if end_sec < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end_sec)
        if start_sec > end_sec:
            raise ValueError("The slice start position (%f s) is later than "
                             "the end position (%f s)." % (start_sec, end_sec))
        if end_sec > self.duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end_sec, self.duration))
        start_sample = int(round(start_sec * self._sample_rate))
        end_sample = int(round(end_sec * self._sample_rate))
        self._samples = self._samples[start_sample:end_sample]

    def random_subsegment(self, subsegment_length, rng=None):
        """Cut the specified length of the audio segment randomly.

        Note that this is an in-place transformation.

        :param subsegment_length: Subsegment length in seconds.
        :type subsegment_length: float
        :param rng: Random number generator state.
        :type rng: random.Random
        :raises ValueError: If the length of subsegment is greater than
                            the original segment.
        """
        rng = random.Random() if rng is None else rng
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = rng.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)

    def convolve(self, impulse_segment, allow_resample=False):
        """Convolve this audio segment with the given impulse segment.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segments.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        :raises ValueError: If the sample rates of the two audio segments do
                            not match and resampling is not allowed.
        """
        if allow_resample and self.sample_rate != impulse_segment.sample_rate:
            impulse_segment.resample(self.sample_rate)
        if self.sample_rate != impulse_segment.sample_rate:
            raise ValueError("Impulse segment's sample rate (%d Hz) is not "
                             "equal to base signal sample rate (%d Hz)." %
                             (impulse_segment.sample_rate, self.sample_rate))
        samples = signal.fftconvolve(self.samples, impulse_segment.samples,
                                     "full")
        self._samples = samples

    def convolve_and_normalize(self, impulse_segment, allow_resample=False):
        """Convolve and normalize the resulting audio segment so that it
        has the same average power as the input signal.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segments.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        """
        target_db = self.rms_db
        self.convolve(impulse_segment, allow_resample=allow_resample)
        self.normalize(target_db)

    def add_noise(self,
                  noise,
                  snr_dB,
                  allow_downsampling=False,
                  max_gain_db=300.0,
                  rng=None):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random subsegment
        of matching length is sampled from it and used instead.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param allow_downsampling: Whether to allow the noise signal to be
                                   downsampled to match the base signal sample
                                   rate.
        :type allow_downsampling: bool
        :param max_gain_db: Maximum amount of gain to apply to noise signal
                            before adding it in. This is to prevent attempting
                            to apply infinite gain to a zero signal.
        :type max_gain_db: float
        :param rng: Random number generator state.
        :type rng: None|random.Random
        :raises ValueError: If the sample rates of the two audio segments do
                            not match when downsampling is not allowed, or if
                            the duration of the noise segment is shorter than
                            the original audio segment.
        """
        rng = random.Random() if rng is None else rng
        if allow_downsampling and noise.sample_rate > self.sample_rate:
            # resample() works in place, so downsample a copy to leave the
            # caller's noise segment untouched
            noise = copy.deepcopy(noise)
            noise.resample(self.sample_rate)
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) is not equal to base "
                             "signal sample rate (%d Hz)." % (noise.sample_rate,
                                                              self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("Noise signal (%f sec) must be at least as long as"
                             " base signal (%f sec)." %
                             (noise.duration, self.duration))
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
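        # After applying this gain the noise RMS becomes
        # noise.rms_db + noise_gain_db, so signal_rms - noise_rms equals the
        # requested snr_dB (unless capped by max_gain_db).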
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)

    @property
    def samples(self):
        """Return audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return root mean square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # mean square is already the squared amplitude, so multiply by 10
        # instead of 20 for decibels
        mean_square = np.mean(self._samples**2)
        return 10 * np.log10(mean_square)

    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or float-point.
        Integers will be scaled to [-1, 1] in float32.
        """
        float32_samples = samples.astype('float32')
        if samples.dtype in np.sctypes['int']:
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / 2**(bits - 1))
|
|
||||||
elif samples.dtype in np.sctypes['float']:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
|
||||||
return float32_samples
|
|
||||||
|
|
||||||
def _convert_samples_from_float32(self, samples, dtype):
|
|
||||||
"""Convert sample type from float32 to dtype.
|
|
||||||
|
|
||||||
Audio sample type is usually integer or float-point. For integer
|
|
||||||
type, float32 will be rescaled from [-1, 1] to the maximum range
|
|
||||||
supported by the integer type.
|
|
||||||
|
|
||||||
This is for writing a audio file.
|
|
||||||
"""
|
|
||||||
dtype = np.dtype(dtype)
|
|
||||||
output_samples = samples.copy()
|
|
||||||
if dtype in np.sctypes['int']:
|
|
||||||
bits = np.iinfo(dtype).bits
|
|
||||||
output_samples *= (2**(bits - 1) / 1.)
|
|
||||||
min_val = np.iinfo(dtype).min
|
|
||||||
max_val = np.iinfo(dtype).max
|
|
||||||
output_samples[output_samples > max_val] = max_val
|
|
||||||
output_samples[output_samples < min_val] = min_val
|
|
||||||
elif samples.dtype in np.sctypes['float']:
|
|
||||||
min_val = np.finfo(dtype).min
|
|
||||||
max_val = np.finfo(dtype).max
|
|
||||||
output_samples[output_samples > max_val] = max_val
|
|
||||||
output_samples[output_samples < min_val] = min_val
|
|
||||||
else:
|
|
||||||
raise TypeError("Unsupported sample type: %s." % samples.dtype)
|
|
||||||
return output_samples.astype(dtype)
|
|
||||||
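# Illustrative sketch (not part of the original file): the int <-> float32
# scaling the two converters above implement, shown standalone with numpy.
# An int16 sample of -32768 maps to -1.0, and the round trip clips to the
# integer type's range before casting back.
import numpy as np

int16_samples = np.array([-32768, 0, 16384], dtype='int16')
float_samples = int16_samples.astype('float32') / 2**15  # [-1.0, 0.0, 0.5]
restored = np.clip(float_samples * 2**15,
                   np.iinfo('int16').min,
                   np.iinfo('int16').max).astype('int16')
assert (restored == int16_samples).all()
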
@@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,134 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
import random

from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.online_bayesian_normalization import \
    OnlineBayesianNormalizationAugmentor
from data_utils.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor


class AugmentationPipeline(object):
    """Build a pre-processing pipeline with various augmentation models. Such
    a data augmentation pipeline is often leveraged to augment the training
    samples, making the model invariant to certain types of perturbation in
    the real world and improving its generalization ability.

    The pipeline is built according to the augmentation configuration in a
    json string, e.g.

    .. code-block::

        [ {
                "type": "noise",
                "params": {"min_snr_dB": 10,
                           "max_snr_dB": 20,
                           "noise_manifest_path": "datasets/manifest.noise"},
                "prob": 0.0
            },
            {
                "type": "speed",
                "params": {"min_speed_rate": 0.9,
                           "max_speed_rate": 1.1},
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {"min_shift_ms": -5,
                           "max_shift_ms": 5},
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {"min_gain_dBFS": -10,
                           "max_gain_dBFS": 10},
                "prob": 0.0
            },
            {
                "type": "bayesian_normal",
                "params": {"target_db": -20,
                           "prior_db": -20,
                           "prior_samples": 100},
                "prob": 0.0
            }
        ]

    This configuration inserts five augmentation models into the pipeline,
    but only the SpeedPerturbAugmentor and the ShiftPerturbAugmentor ever
    fire: "prob" is the probability of each augmentor taking effect, and it
    is zero for the other three.

    :param augmentation_config: Augmentation configuration in json string.
    :type augmentation_config: str
    :param random_seed: Random seed.
    :type random_seed: int
    :raises ValueError: If the augmentation json config is in an incorrect
                        format.
    """

    def __init__(self, augmentation_config, random_seed=0):
        self._rng = random.Random(random_seed)
        self._augmentors, self._rates = self._parse_pipeline_from(
            augmentation_config)

    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._augmentors, self._rates):
            if self._rng.uniform(0., 1.) < rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        """Parse the config json to build an augmentation pipeline."""
        try:
            configs = json.loads(config_json)
            augmentors = [
                self._get_augmentor(config["type"], config["params"])
                for config in configs
            ]
            rates = [config["prob"] for config in configs]
        except Exception as e:
            raise ValueError("Failed to parse the augmentation config json: "
                             "%s" % str(e))
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params."""
        if augmentor_type == "volume":
            return VolumePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "shift":
            return ShiftPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "speed":
            return SpeedPerturbAugmentor(self._rng, **params)
        elif augmentor_type == "resample":
            return ResampleAugmentor(self._rng, **params)
        elif augmentor_type == "bayesian_normal":
            return OnlineBayesianNormalizationAugmentor(self._rng, **params)
        elif augmentor_type == "noise":
            return NoisePerturbAugmentor(self._rng, **params)
        elif augmentor_type == "impulse":
            return ImpulseResponseAugmentor(self._rng, **params)
        else:
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)

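# Usage sketch (illustrative, not part of the original file; the wav path is
# a placeholder): build a pipeline that always applies speed and shift
# perturbation, then run it over one segment. Augmentation happens in place.
import json

from data_utils.audio import AudioSegment
from data_utils.augmentor.augmentation import AugmentationPipeline

config = json.dumps([
    {"type": "speed", "params": {"min_speed_rate": 0.9,
                                 "max_speed_rate": 1.1}, "prob": 1.0},
    {"type": "shift", "params": {"min_shift_ms": -5,
                                 "max_shift_ms": 5}, "prob": 1.0},
])
pipeline = AugmentationPipeline(augmentation_config=config, random_seed=0)
segment = AudioSegment.from_file('sample.wav')
pipeline.transform_audio(segment)  # segment is modified in place
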
@@ -1,43 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the abstract base class for augmentation models."""
from abc import ABCMeta
from abc import abstractmethod


class AugmentorBase(metaclass=ABCMeta):
    """Abstract base class for augmentation model (augmentor) classes.
    All augmentor classes should inherit from this class and implement the
    following abstract methods.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Adds various effects to the input audio segment. Such effects
        will augment the training data to make the model invariant to certain
        types of perturbations in the real world, improving the model's
        generalization ability.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        pass

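# Illustrative sketch (not part of the original file): a minimal custom
# augmentor following the AugmentorBase contract. gain_db() is assumed to be
# the AudioSegment method used elsewhere in this pipeline (e.g. by
# VolumePerturbAugmentor).
class RandomGainAugmentor(AugmentorBase):
    """Toy augmentor: apply a uniformly sampled gain in [-6, 6] dB."""

    def __init__(self, rng):
        self._rng = rng

    def transform_audio(self, audio_segment):
        # in-place, like every augmentor in this pipeline
        audio_segment.gain_db(self._rng.uniform(-6.0, 6.0))
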
@@ -1,43 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the impulse response augmentation model."""
from data_utils.audio import AudioSegment
from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest


class ImpulseResponseAugmentor(AugmentorBase):
    """Augmentation model for adding impulse response effect.

    :param rng: Random generator object.
    :type rng: random.Random
    :param impulse_manifest_path: Manifest path for impulse audio data.
    :type impulse_manifest_path: str
    """

    def __init__(self, rng, impulse_manifest_path):
        self._rng = rng
        self._impulse_manifest = read_manifest(impulse_manifest_path)

    def transform_audio(self, audio_segment):
        """Add impulse response effect.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        impulse_json = self._rng.sample(self._impulse_manifest, 1)[0]
        impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath'])
        audio_segment.convolve(impulse_segment, allow_resample=True)

@@ -1,58 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the noise perturb augmentation model."""
from data_utils.audio import AudioSegment
from data_utils.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest


class NoisePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding background noise.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_snr_dB: Minimum signal-to-noise ratio, in decibels.
    :type min_snr_dB: float
    :param max_snr_dB: Maximum signal-to-noise ratio, in decibels.
    :type max_snr_dB: float
    :param noise_manifest_path: Manifest path for noise audio data.
    :type noise_manifest_path: str
    """

    def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
        self._min_snr_dB = min_snr_dB
        self._max_snr_dB = max_snr_dB
        self._rng = rng
        self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)

    def transform_audio(self, audio_segment):
        """Add background noise audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        noise_json = self._rng.sample(self._noise_manifest, 1)[0]
        if noise_json['duration'] < audio_segment.duration:
            raise RuntimeError("The duration of the sampled noise audio is "
                               "smaller than the audio segment to add effects "
                               "to.")
        diff_duration = noise_json['duration'] - audio_segment.duration
        start = self._rng.uniform(0, diff_duration)
        end = start + audio_segment.duration
        noise_segment = AudioSegment.slice_from_file(
            noise_json['audio_filepath'], start=start, end=end)
        snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
        audio_segment.add_noise(
            noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)

@@ -1,57 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the online Bayesian normalization augmentation model."""
from data_utils.augmentor.base import AugmentorBase


class OnlineBayesianNormalizationAugmentor(AugmentorBase):
    """Augmentation model for adding online Bayesian normalization.

    :param rng: Random generator object.
    :type rng: random.Random
    :param target_db: Target RMS value in decibels.
    :type target_db: float
    :param prior_db: Prior RMS estimate in decibels.
    :type prior_db: float
    :param prior_samples: Prior strength in number of samples.
    :type prior_samples: int
    :param startup_delay: Default 0.0s. If provided, accrue statistics for
                          the first startup_delay seconds before applying
                          online normalization.
    :type startup_delay: float
    """

    def __init__(self,
                 rng,
                 target_db,
                 prior_db,
                 prior_samples,
                 startup_delay=0.0):
        self._target_db = target_db
        self._prior_db = prior_db
        self._prior_samples = prior_samples
        self._rng = rng
        self._startup_delay = startup_delay

    def transform_audio(self, audio_segment):
        """Normalize the input audio using the online Bayesian approach.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.normalize_online_bayesian(self._target_db,
                                                self._prior_db,
                                                self._prior_samples,
                                                self._startup_delay)

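# Illustrative sketch (not part of the original file): one plausible form of
# the running level estimate behind normalize_online_bayesian(), whose actual
# implementation lives in data_utils.audio. The prior acts like
# `prior_samples` pseudo-observations at `prior_db`, so early samples are
# pulled toward the prior and the estimate converges to the data over time.
import numpy as np

def bayesian_rms_db_estimate(samples, prior_db, prior_samples):
    prior_mean_sq = 10.0**(prior_db / 10.0)
    cumsum_sq = np.cumsum(samples**2)
    counts = np.arange(1, len(samples) + 1)
    mean_sq = (cumsum_sq + prior_samples * prior_mean_sq) / (
        counts + prior_samples)
    return 10.0 * np.log10(mean_sq)  # per-sample running RMS in dB

# The per-sample gain to apply is then target_db minus this estimate.
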
@@ -1,42 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the resample augmentation model."""
from data_utils.augmentor.base import AugmentorBase


class ResampleAugmentor(AugmentorBase):
    """Augmentation model for resampling.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param rng: Random generator object.
    :type rng: random.Random
    :param new_sample_rate: New sample rate in Hz.
    :type new_sample_rate: int
    """

    def __init__(self, rng, new_sample_rate):
        self._new_sample_rate = new_sample_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Resample the input audio to the target sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.resample(self._new_sample_rate)

@@ -1,43 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the shift perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random shift perturbation.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_shift_ms: Minimum shift in milliseconds.
    :type min_shift_ms: float
    :param max_shift_ms: Maximum shift in milliseconds.
    :type max_shift_ms: float
    """

    def __init__(self, rng, min_shift_ms, max_shift_ms):
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Shift audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(shift_ms)

@@ -1,56 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speed perturbation augmentation model."""
from data_utils.augmentor.base import AugmentorBase


class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding speed perturbation.

    See reference paper here:
    http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_speed_rate: Lower bound of the new speed rate to sample;
                           should not be smaller than 0.9.
    :type min_speed_rate: float
    :param max_speed_rate: Upper bound of the new speed rate to sample;
                           should not be larger than 1.1.
    :type max_speed_rate: float
    """

    def __init__(self, rng, min_speed_rate, max_speed_rate):
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        self._min_speed_rate = min_speed_rate
        self._max_speed_rate = max_speed_rate
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Sample a new speed rate from the given range and change the speed
        of the given audio clip.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        sampled_speed = self._rng.uniform(self._min_speed_rate,
                                          self._max_speed_rate)
        audio_segment.change_speed(sampled_speed)

@@ -1,49 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the volume perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase


class VolumePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random volume perturbation.

    This is used for multi-loudness training of PCEN. See

    https://arxiv.org/pdf/1607.05666v1.pdf

    for more details.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_gain_dBFS: Minimum gain in dBFS.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Maximum gain in dBFS.
    :type max_gain_dBFS: float
    """

    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Change audio loudness.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)

@@ -1,380 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data generator for organizing various audio data preprocessing
pipelines and offering a data reader interface meeting PaddlePaddle's
requirements.
"""
import random
import tarfile
from threading import local

import numpy as np
import paddle.fluid as fluid
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.normalizer import FeatureNormalizer
from data_utils.speech import SpeechSegment
from data_utils.utility import read_manifest


class DataGenerator(object):
    """
    DataGenerator provides a basic audio data preprocessing pipeline and
    offers data reader interfaces meeting PaddlePaddle's requirements.

    :param vocab_filepath: Vocabulary filepath for indexing tokenized
                           transcripts.
    :type vocab_filepath: str
    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|str
    :param augmentation_config: Augmentation configuration in json string.
                                For details see AugmentationPipeline.__doc__.
    :type augmentation_config: str
    :param max_duration: Audio with duration (in seconds) greater than
                         this will be discarded.
    :type max_duration: float
    :param min_duration: Audio with duration (in seconds) smaller than
                         this will be discarded.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned.
    :type max_freq: None|float
    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param use_dB_normalization: Whether to normalize the audio to -20 dB
                                 before extracting the features.
    :type use_dB_normalization: bool
    :param random_seed: Random seed.
    :type random_seed: int
    :param keep_transcription_text: If set to True, transcription text will
                                    be passed forward directly without
                                    converting to index sequence.
    :type keep_transcription_text: bool
    :param place: The place to run the program.
    :type place: CPUPlace or CUDAPlace
    :param is_training: If set to True, generate text data for training;
                        otherwise, generate text data for inference.
    :type is_training: bool
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 specgram_type='linear',
                 use_dB_normalization=True,
                 random_seed=0,
                 keep_transcription_text=False,
                 place=fluid.CPUPlace(),
                 is_training=True):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            use_dB_normalization=use_dB_normalization)
        self._rng = random.Random(random_seed)
        self._keep_transcription_text = keep_transcription_text
        self._epoch = 0
        self._is_training = is_training
        # for caching tar files info
        self._local_data = local()
        self._local_data.tar2info = {}
        self._local_data.tar2object = {}
        self._place = place

    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of the audio file.
        :type audio_file: str | file
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of the transcription
                 part, where the transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        specgram = self._normalizer.apply(specgram)
        return specgram, transcript_part

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             shuffle_method="batch_shuffle"):
        """
        Batch data reader creator for audio data. Return a callable generator
        function to produce batches of data.

        Audio features within one batch will be padded with zeros to have the
        same shape, or a user-defined shape.

        :param manifest_path: Filepath of manifest for audio files.
        :type manifest_path: str
        :param batch_size: Number of instances in a batch.
        :type batch_size: int
        :param padding_to: If set to -1, the maximum shape in the batch
                           will be used as the target shape for padding.
                           Otherwise, `padding_to` will be the target shape.
        :type padding_to: int
        :param flatten: If set True, audio features will be flattened to a
                        1darray.
        :type flatten: bool
        :param sortagrad: If set True, sort the instances by audio duration
                          in the first epoch to speed up training.
        :type sortagrad: bool
        :param shuffle_method: Shuffle method. Options:
                               '' or None: no shuffle.
                               'instance_shuffle': instance-wise shuffle.
                               'batch_shuffle': similarly-sized instances are
                                                put into minibatches, and the
                                                minibatches are then shuffled
                                                batch-wise. For more details,
                                                please see
                                                ``_batch_shuffle.__doc__``.
                               'batch_shuffle_clipped': 'batch_shuffle' with
                                                        head shift and tail
                                                        clipping. For more
                                                        details, please see
                                                        ``_batch_shuffle``.
                               If sortagrad is True, shuffle is disabled
                               for the first epoch.
        :type shuffle_method: None|str
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            else:
                if shuffle_method == "batch_shuffle":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=False)
                elif shuffle_method == "batch_shuffle_clipped":
                    manifest = self._batch_shuffle(
                        manifest, batch_size, clipped=True)
                elif shuffle_method == "instance_shuffle":
                    self._rng.shuffle(manifest)
                elif shuffle_method is None:
                    pass
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     shuffle_method)
            # prepare batches
            batch = []
            instance_reader = self._instance_reader_creator(manifest)

            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
                    batch = []
            if len(batch) >= 1:
                yield self._padding_batch(batch, padding_to, flatten)
            self._epoch += 1

        return batch_reader

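    # Usage sketch (illustrative, not part of the original class; file paths
    # are placeholders):
    #
    #   generator = DataGenerator(vocab_filepath='vocab.txt',
    #                             mean_std_filepath='mean_std.npz')
    #   reader = generator.batch_reader_creator(
    #       manifest_path='manifest.train', batch_size=16, sortagrad=True)
    #   for padded_audios, texts, audio_lens, masks in reader():
    #       ...  # feed one zero-padded batch to the network
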
    @property
    def feeding(self):
        """Return the data reader's feeding dict.

        :return: Data feeding dict.
        :rtype: dict
        """
        feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
        return feeding_dict

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._speech_featurizer.vocab_list

    def _parse_tar(self, file):
        """Parse a tar file into a tarfile object and a dict mapping member
        names to their tarinfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get a subfile object from a tar archive.

        Returns a file object for the requested member and caches the tar
        file's info for subsequent reading requests.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator. Create a callable function to produce
        instances of data.

        An instance is a tuple of an ndarray of audio spectrogram and a list
        of token indices for the transcript.
        """

        def reader():
            for instance in manifest:
                inst = self.process_utterance(instance["audio_filepath"],
                                              instance["text"])
                yield inst

        return reader

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Pad audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.

        If ``padding_to`` is -1, the maximum shape in the batch will be used
        as the target shape for padding. Otherwise, `padding_to` will be the
        target shape (only refers to the second axis).

        If `flatten` is True, features will be flattened to a 1darray.
        """
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be larger "
                                 "than any instance's shape in the batch")
            max_length = padding_to
        # padding
        padded_audios = []
        texts, text_lens = [], []
        audio_lens = []
        masks = []
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            padded_audios.append(padded_audio)
            if self._is_training:
                texts += text
            else:
                texts.append(text)
            text_lens.append(len(text))
            audio_lens.append(audio.shape[1])
            # the mask mirrors the downstream conv feature map: frequency
            # axis subsampled by 2, time axis by 3, replicated across 32
            # channels (shapes assumed by the network's first conv block)
            mask_shape0 = (audio.shape[0] - 1) // 2 + 1
            mask_shape1 = (audio.shape[1] - 1) // 3 + 1
            mask_max_len = (max_length - 1) // 3 + 1
            mask_ones = np.ones((mask_shape0, mask_shape1))
            mask_zeros = np.zeros((mask_shape0, mask_max_len - mask_shape1))
            mask = np.repeat(
                np.reshape(
                    np.concatenate((mask_ones, mask_zeros), axis=1),
                    (1, mask_shape0, mask_max_len)),
                32,
                axis=0)
            masks.append(mask)
        padded_audios = np.array(padded_audios).astype('float32')
        if self._is_training:
            texts = np.expand_dims(np.array(texts).astype('int32'), axis=-1)
            texts = fluid.create_lod_tensor(
                texts, recursive_seq_lens=[text_lens], place=self._place)
        audio_lens = np.array(audio_lens).astype('int64').reshape([-1, 1])
        masks = np.array(masks).astype('float32')
        return padded_audios, texts, audio_lens, masks

    def _batch_shuffle(self, manifest, batch_size, clipped=False):
        """Put similarly-sized instances into minibatches for better
        efficiency, and make a batch-wise shuffle.

        1. Sort the audio clips by duration.
        2. Generate a random number `k`, k in [0, batch_size).
        3. Randomly shift `k` instances in order to create different batches
           for different epochs. Create minibatches.
        4. Shuffle the minibatches.

        :param manifest: Manifest contents. List of dict.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generating
                           a random number for batch shuffle.
        :type batch_size: int
        :param clipped: Whether to clip the heading (small shift) and trailing
                        (incomplete batch) instances.
        :type clipped: bool
        :return: Batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self._rng.shuffle(batch_manifest)
        batch_manifest = [item for batch in batch_manifest for item in batch]
        if not clipped:
            res_len = len(manifest) - shift_len - len(batch_manifest)
            # res_len may be 0, and manifest[-0:] would re-append the entire
            # manifest instead of an empty tail, so guard against it
            if res_len > 0:
                batch_manifest.extend(manifest[-res_len:])
            batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest

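    # Worked example (illustrative): with batch_size=4, a sorted manifest of
    # 11 clips [a..k], and a sampled shift_len of 2, zip(*[iter(...)]*4)
    # forms the minibatches [c,d,e,f] and [g,h,i,j]; the two batches are then
    # shuffled as units; with clipped=False the tail [k] and the head [a,b]
    # are appended back, so no utterance is dropped.
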
@@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,194 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio featurizer class."""
import numpy as np
from python_speech_features import delta
from python_speech_features import mfcc


class AudioFeaturizer(object):
    """Audio featurizer, for extracting features from audio contents of
    AudioSegment or SpeechSegment.

    Currently, it supports feature types of linear spectrogram and mfcc.

    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of the mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Audio is resampled (if upsampling or
                               downsampling is allowed) to this rate before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
        self._target_sample_rate = target_sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB

    def featurize(self,
                  audio_segment,
                  allow_downsampling=True,
                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
        :param allow_downsampling: Whether to allow audio downsampling before
                                   featurizing.
        :type allow_downsampling: bool
        :param allow_upsampling: Whether to allow audio upsampling before
                                 featurizing.
        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        :raises ValueError: If the audio sample rate is not supported.
        """
        # upsampling or downsampling
        if ((audio_segment.sample_rate > self._target_sample_rate and
             allow_downsampling) or
            (audio_segment.sample_rate < self._target_sample_rate and
             allow_upsampling)):
            audio_segment.resample(self._target_sample_rate)
        if audio_segment.sample_rate != self._target_sample_rate:
            raise ValueError("Audio sample rate is not supported. "
                             "Turn allow_downsampling or allow_upsampling on.")
        # decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # extract spectrogram
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)

    def _compute_specgram(self, samples, sample_rate):
        """Extract various audio features."""
        if self._specgram_type == 'linear':
            return self._compute_linear_specgram(
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        elif self._specgram_type == 'mfcc':
            return self._compute_mfcc(samples, sample_rate, self._stride_ms,
                                      self._window_ms, self._max_freq)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: 'linear', 'mfcc'." %
                             self._specgram_type)

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Compute the linear spectrogram from FFT energy."""
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(specgram[:ind, :] + eps)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram for samples from a real signal."""
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)
        fft = fft**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

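    # Sanity-check sketch (illustrative, not part of the original class): a
    # 1 kHz sine at 16 kHz should concentrate energy near bin
    # round(1000 / (sample_rate / window_size)). With window_ms=20 the window
    # is 320 samples, so the bin spacing is 50 Hz and the peak lands at
    # index 20:
    #
    #   featurizer = AudioFeaturizer()  # defaults: 'linear', 20 ms window
    #   t = np.arange(16000) / 16000.0
    #   sine = np.sin(2 * np.pi * 1000 * t)
    #   spec = featurizer._compute_linear_specgram(sine, 16000)
    #   assert np.argmax(spec.mean(axis=1)) == 20
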
    def _compute_mfcc(self,
                      samples,
                      sample_rate,
                      stride_ms=10.0,
                      window_ms=20.0,
                      max_freq=None):
        """Compute mfcc from samples."""
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # compute the 13 cepstral coefficients; the first one is replaced
        # by log(frame energy)
        mfcc_feat = mfcc(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            highfreq=max_freq)
        # deltas
        d_mfcc_feat = delta(mfcc_feat, 2)
        # delta-deltas
        dd_mfcc_feat = delta(d_mfcc_feat, 2)
        # transpose to (feature, time)
        mfcc_feat = np.transpose(mfcc_feat)
        d_mfcc_feat = np.transpose(d_mfcc_feat)
        dd_mfcc_feat = np.transpose(dd_mfcc_feat)
        # concatenate the three feature blocks along the feature axis
        concat_mfcc_feat = np.concatenate(
            (mfcc_feat, d_mfcc_feat, dd_mfcc_feat))
        return concat_mfcc_feat

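    # Shape sketch (illustrative, not part of the original class): the mfcc
    # path stacks 13 cepstra, 13 deltas, and 13 delta-deltas, so each frame
    # becomes a 39-dimensional column. For one second of 16 kHz audio with a
    # 10 ms stride, _compute_mfcc returns roughly a (39, ~100) array, versus
    # (161, ~100) for the 'linear' spectrogram with 20 ms windows and
    # max_freq equal to the 8000 Hz Nyquist frequency.
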
@ -1,107 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech featurizer class."""
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer


class SpeechFeaturizer(object):
    """Speech featurizer, for extracting features from both the audio and the
    transcript contents of a SpeechSegment.

    Currently, for audio parts, it supports feature types of linear
    spectrogram and mfcc; for transcript parts, it only supports char-level
    tokenizing and conversion into a list of token indices. Note that the
    token indexing order follows the given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: str
    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of the mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Speech is resampled (if upsampling or
                               downsampling is allowed) to this rate before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 vocab_filepath,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment, keep_transcription_text):
        """Extract features for a speech segment.

        1. For audio parts, extract the audio features.
        2. For transcript parts, keep the original text or convert the text
           string to a list of char-level token indices.

        :param speech_segment: Speech segment to extract features from.
        :type speech_segment: SpeechSegment
        :param keep_transcription_text: Whether to return the raw transcript
                                        text instead of token indices.
        :type keep_transcription_text: bool
        :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of
                 char-level token indices (or the raw transcript text).
        :rtype: tuple
        """
        audio_feature = self._audio_featurizer.featurize(speech_segment)
        if keep_transcription_text:
            return audio_feature, speech_segment.transcript
        text_ids = self._text_featurizer.featurize(speech_segment.transcript)
        return audio_feature, text_ids

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._text_featurizer.vocab_list
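# Usage sketch (illustrative addition, not part of the original file;
# 'vocab.txt' is a hypothetical one-token-per-line vocabulary file).
if __name__ == '__main__':
    featurizer = SpeechFeaturizer(
        vocab_filepath='vocab.txt',
        specgram_type='linear',
        stride_ms=10.0,
        window_ms=20.0)
    # A CTC output layer is typically sized vocab_size + 1 (for the blank).
    print(featurizer.vocab_size, featurizer.vocab_list[:5])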
@ -1,76 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
import codecs


class TextFeaturizer(object):
    """Text featurizer, for processing or extracting features from text.

    Currently, it only supports char-level tokenizing and conversion into
    a list of token indices. Note that the token indexing order follows the
    given vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: str
    """

    def __init__(self, vocab_filepath):
        self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
            vocab_filepath)

    def featurize(self, text):
        """Convert a text string to a list of char-level token indices. Note
        that the token indexing order follows the given vocabulary file.

        :param text: Text to process.
        :type text: str
        :return: List of char-level token indices.
        :rtype: list
        """
        tokens = self._char_tokenize(text)
        return [self._vocab_dict[token] for token in tokens]

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self._vocab_list)

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        :return: Vocabulary in list.
        :rtype: list
        """
        return self._vocab_list

    def _char_tokenize(self, text):
        """Character tokenizer."""
        return list(text.strip())

    def _load_vocabulary_from_file(self, vocab_filepath):
        """Load vocabulary from file."""
        vocab_lines = []
        with codecs.open(vocab_filepath, 'r', 'utf-8') as file:
            vocab_lines.extend(file.readlines())
        vocab_list = [line[:-1] for line in vocab_lines]
        vocab_dict = dict(
            [(token, id) for (id, token) in enumerate(vocab_list)])
        return vocab_dict, vocab_list
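# Usage sketch (illustrative addition, not part of the original file). The
# vocabulary file is one token per line; indices follow line order, so a
# file containing "a\nb\nc\n" maps 'a'->0, 'b'->1, 'c'->2.
if __name__ == '__main__':
    with codecs.open('vocab_demo.txt', 'w', 'utf-8') as f:
        f.write('a\nb\nc\n')
    featurizer = TextFeaturizer('vocab_demo.txt')
    assert featurizer.featurize('cab') == [2, 0, 1]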
@ -1,97 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains feature normalizers."""
import random

import numpy as np
from data_utils.audio import AudioSegment
from data_utils.utility import read_manifest


class FeatureNormalizer(object):
    """Feature normalizer. Normalize features to be of zero mean and unit
    stddev.

    If mean_std_filepath is provided (not None), the normalizer will directly
    initialize from the file. Otherwise, both manifest_path and featurize_func
    should be given for on-the-fly mean and stddev computing.

    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|str
    :param manifest_path: Manifest of instances for computing mean and stddev.
    :type manifest_path: None|str
    :param featurize_func: Function to extract features. It should be callable
                           with ``featurize_func(audio_segment)``.
    :type featurize_func: None|callable
    :param num_samples: Number of random samples for computing mean and stddev.
    :type num_samples: int
    :param random_seed: Random seed for sampling instances.
    :type random_seed: int
    :raises ValueError: If both mean_std_filepath and manifest_path
                        (or both mean_std_filepath and featurize_func) are None.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = random.Random(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features, eps=1e-14):
        """Normalize features to be of zero mean and unit stddev.

        :param features: Input features to be normalized.
        :type features: ndarray
        :param eps: Added to stddev to provide numerical stability.
        :type eps: float
        :return: Normalized features.
        :rtype: ndarray
        """
        return (features - self._mean) / (self._std + eps)

    def write_to_file(self, filepath):
        """Write the mean and stddev to the file.

        :param filepath: File to write mean and stddev.
        :type filepath: str
        """
        np.savez(filepath, mean=self._mean, std=self._std)

    def _read_mean_std_from_file(self, filepath):
        """Load mean and std from file."""
        npzfile = np.load(filepath)
        self._mean = npzfile["mean"]
        self._std = npzfile["std"]

    def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
        """Compute mean and std from randomly sampled instances."""
        manifest = read_manifest(manifest_path)
        sampled_manifest = self._rng.sample(manifest, num_samples)
        features = []
        for instance in sampled_manifest:
            features.append(
                featurize_func(
                    AudioSegment.from_file(instance["audio_filepath"])))
        features = np.hstack(features)
        self._mean = np.mean(features, axis=1).reshape([-1, 1])
        self._std = np.std(features, axis=1).reshape([-1, 1])
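# Usage sketch (illustrative addition, not part of the original file;
# 'manifest.train' and 'mean_std.npz' are hypothetical paths, and the
# demo featurize function stands in for e.g. AudioFeaturizer.featurize;
# it assumes the manifest lists at least num_samples instances). Stats
# are computed once over sampled instances, written to disk, and reloaded.
if __name__ == '__main__':
    def _demo_featurize(audio_segment):
        # Features must be [feature_dim x time]; instances are hstacked.
        return np.vstack([audio_segment.samples, audio_segment.samples])

    normalizer = FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path='manifest.train',
        featurize_func=_demo_featurize,
        num_samples=500)
    normalizer.write_to_file('mean_std.npz')
    normalizer = FeatureNormalizer(mean_std_filepath='mean_std.npz')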
@ -1,153 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech segment class."""
import numpy as np
from data_utils.audio import AudioSegment


class SpeechSegment(AudioSegment):
    """Speech segment abstraction, a subclass of AudioSegment,
    with an additional transcript.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :param transcript: Transcript text for the speech.
    :type transcript: str
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    def __eq__(self, other):
        """Return whether two objects are equal."""
        if not AudioSegment.__eq__(self, other):
            return False
        if self._transcript != other._transcript:
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    @classmethod
    def from_file(cls, filepath, transcript):
        """Create speech segment from audio file and corresponding transcript.

        :param filepath: Filepath or file object to audio file.
        :type filepath: str|file
        :param transcript: Transcript text for the speech.
        :type transcript: str
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def from_bytes(cls, bytes, transcript):
        """Create speech segment from a byte string and corresponding
        transcript.

        :param bytes: Byte string containing audio samples.
        :type bytes: str
        :param transcript: Transcript text for the speech.
        :type transcript: str
        :return: Speech segment instance.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of speech segments together; both
        audio and transcript will be concatenated.

        :param *segments: Input speech segments to be concatenated.
        :type *segments: tuple of SpeechSegment
        :return: Speech segment instance.
        :rtype: SpeechSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any two segments does not match.
        :raises TypeError: If any segment is not a SpeechSegment instance.
        """
        if len(segments) == 0:
            raise ValueError("No speech segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        transcripts = ""
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only speech segments of the same type "
                                "can be concatenated.")
            transcripts += seg._transcript
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate, transcripts)

    @classmethod
    def slice_from_file(cls, filepath, transcript, start=None, end=None):
        """Load a small section of a speech file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param filepath: Filepath or file object to audio file.
        :type filepath: str|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :param transcript: Transcript text for the speech. If not provided,
                           the default is an empty string.
        :type transcript: str
        :return: SpeechSegment instance of the specified slice of the input
                 speech file.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.slice_from_file(filepath, start, end)
        return cls(audio.samples, audio.sample_rate, transcript)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Create a silent speech segment of the given duration and
        sample rate; the transcript will be an empty string.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silence of the given duration.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.make_silence(duration, sample_rate)
        return cls(audio.samples, audio.sample_rate, "")

    @property
    def transcript(self):
        """Return the transcript text.

        :return: Transcript text for the speech.
        :rtype: str
        """
        return self._transcript
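# Usage sketch (illustrative addition, not part of the original file;
# 'a.wav' and 'b.wav' are hypothetical). Concatenation joins samples and
# transcripts alike.
if __name__ == '__main__':
    seg_a = SpeechSegment.from_file('a.wav', transcript='hello ')
    seg_b = SpeechSegment.from_file('b.wav', transcript='world')
    joined = SpeechSegment.concatenate(seg_a, seg_b)
    assert joined.transcript == 'hello world'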
@ -1,98 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import codecs
import json
import os
import tarfile

from paddle.dataset.common import md5file


def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Instances with durations outside [min_duration, max_duration] will be
    filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: str
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
        try:
            json_data = json.loads(json_line)
        except Exception as e:
            raise IOError("Error reading manifest: %s" % str(e))
        if (json_data["duration"] <= max_duration and
                json_data["duration"] >= min_duration):
            manifest.append(json_data)
    return manifest


def getfile_insensitive(path):
    """Get the actual file path when given a case-insensitive filename."""
    directory, filename = os.path.split(path)
    directory, filename = (directory or '.'), filename.lower()
    for f in os.listdir(directory):
        newpath = os.path.join(directory, f)
        if os.path.isfile(newpath) and f.lower() == filename:
            return newpath


def download_multi(url, target_dir, extra_args):
    """Download multiple files from url to target_dir."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    print("Downloading %s ..." % url)
    ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
                         target_dir)
    return ret_code


def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Checksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath


def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    tar = tarfile.open(filepath)
    tar.extractall(target_dir)
    tar.close()
    if rm_tar is True:
        os.remove(filepath)


class XmapEndSignal():
    pass
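# Usage sketch (illustrative addition, not part of the original file).
# Each manifest line is a standalone JSON object, e.g.:
#   {"audio_filepath": "a.wav", "duration": 3.2, "text": "..."}
if __name__ == '__main__':
    with codecs.open('manifest_demo', 'w', 'utf-8') as f:
        f.write(json.dumps({
            "audio_filepath": "a.wav", "duration": 3.2, "text": "demo"
        }) + '\n')
    entries = read_manifest('manifest_demo', max_duration=10.0)
    assert len(entries) == 1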
@ -1,3 +0,0 @@
# Reference
* [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
* [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,248 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
import multiprocessing
from itertools import groupby
from math import log

import numpy as np


def ctc_greedy_decoder(probs_seq, vocabulary):
    """CTC greedy (best path) decoder.

    The path consisting of the most probable tokens is further post-processed
    to remove consecutive repetitions and all blanks.

    :param probs_seq: 2-D list of probabilities over the vocabulary for each
                      character. Each element is a list of float probabilities
                      for one character.
    :type probs_seq: list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :return: Decoding result string.
    :rtype: str
    """
    # dimension verification
    for probs in probs_seq:
        if not len(probs) == len(vocabulary) + 1:
            raise ValueError("probs_seq dimension mismatched with vocabulary")
    # argmax to get the best index for each time step
    max_index_list = list(np.array(probs_seq).argmax(axis=1))
    # remove consecutive duplicate indexes
    index_list = [index_group[0] for index_group in groupby(max_index_list)]
    # remove blank indexes
    blank_index = len(vocabulary)
    index_list = [index for index in index_list if index != blank_index]
    # convert index list to string
    return ''.join([vocabulary[index] for index in index_list])


def ctc_beam_search_decoder(probs_seq,
                            beam_size,
                            vocabulary,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            nproc=False):
    """CTC beam search decoder.

    It utilizes beam search to approximately select the top best decoding
    labels and returns results in descending order.
    The implementation is based on Prefix Beam Search
    (https://arxiv.org/abs/1408.2873), with the unclear parts redesigned.
    Two important modifications: 1) in the iterative computation of
    probabilities, the assignment operation is changed to accumulation
    because one prefix may come from different paths; 2) the if condition
    "if l^+ not in A_prev then" after the probabilities' computation is
    removed, since it is hard to understand and seems unnecessary.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param nproc: Whether the decoder is used in multiprocessing.
    :type nproc: bool
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    # dimension check
    for prob_list in probs_seq:
        if not len(prob_list) == len(vocabulary) + 1:
            raise ValueError("The shape of prob_seq does not match with the "
                             "shape of the vocabulary.")

    # blank_id assign
    blank_id = len(vocabulary)

    # If the decoder is called in multiprocessing, use the global scorer
    # instantiated in ctc_beam_search_decoder_batch().
    if nproc is True:
        global ext_nproc_scorer
        ext_scoring_func = ext_nproc_scorer

    # initialize
    # prefix_set_prev: the set containing selected prefixes
    # probs_b_prev: prefixes' probability ending with blank in previous step
    # probs_nb_prev: prefixes' probability ending with non-blank in previous step
    prefix_set_prev = {'\t': 1.0}
    probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}

    # extend prefix in loop
    for time_step in range(len(probs_seq)):
        # prefix_set_next: the set containing candidate prefixes
        # probs_b_cur: prefixes' probability ending with blank in current step
        # probs_nb_cur: prefixes' probability ending with non-blank in current step
        prefix_set_next, probs_b_cur, probs_nb_cur = {}, {}, {}

        prob_idx = list(enumerate(probs_seq[time_step]))
        cutoff_len = len(prob_idx)
        # If pruning is enabled
        if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
            prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
            cutoff_len, cum_prob = 0, 0.0
            for i in range(len(prob_idx)):
                cum_prob += prob_idx[i][1]
                cutoff_len += 1
                if cum_prob >= cutoff_prob:
                    break
            cutoff_len = min(cutoff_len, cutoff_top_n)
            prob_idx = prob_idx[0:cutoff_len]

        for l in prefix_set_prev:
            if l not in prefix_set_next:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend prefix by traversing prob_idx
            for index in range(cutoff_len):
                c, prob_c = prob_idx[index][0], prob_idx[index][1]

                if c == blank_id:
                    probs_b_cur[l] += prob_c * (
                        probs_b_prev[l] + probs_nb_prev[l])
                else:
                    last_char = l[-1]
                    new_char = vocabulary[c]
                    l_plus = l + new_char
                    if l_plus not in prefix_set_next:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if new_char == last_char:
                        probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
                        probs_nb_cur[l] += prob_c * probs_nb_prev[l]
                    elif new_char == ' ':
                        if (ext_scoring_func is None) or (len(l) == 1):
                            score = 1.0
                        else:
                            prefix = l[1:]
                            score = ext_scoring_func(prefix)
                        probs_nb_cur[l_plus] += score * prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    else:
                        probs_nb_cur[l_plus] += prob_c * (
                            probs_b_prev[l] + probs_nb_prev[l])
                    # add l_plus into prefix_set_next
                    prefix_set_next[l_plus] = probs_nb_cur[
                        l_plus] + probs_b_cur[l_plus]
            # add l into prefix_set_next
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
        # update probs
        probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur

        # store top beam_size prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0 and len(seq) > 1:
            result = seq[1:]
            # score last word by external scorer
            if (ext_scoring_func is not None) and (result[-1] != ' '):
                prob = prob * ext_scoring_func(result)
            log_prob = log(prob)
            beam_result.append((log_prob, result))
        else:
            beam_result.append((float('-inf'), ''))

    # output top beam_size decoding results
    beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
    return beam_result


def ctc_beam_search_decoder_batch(probs_split,
                                  beam_size,
                                  vocabulary,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None):
    """CTC beam search decoder using multiple processes.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    if not num_processes > 0:
        raise ValueError("Number of processes must be positive!")

    # use global variable to pass the external scorer to beam search decoder
    global ext_nproc_scorer
    ext_nproc_scorer = ext_scoring_func
    nproc = True

    pool = multiprocessing.Pool(processes=num_processes)
    results = []
    for i, probs_list in enumerate(probs_split):
        args = (probs_list, beam_size, vocabulary, cutoff_prob, cutoff_top_n,
                None, nproc)
        results.append(pool.apply_async(ctc_beam_search_decoder, args))

    pool.close()
    pool.join()
    beam_search_results = [result.get() for result in results]
    return beam_search_results
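# Worked example (illustrative addition, not part of the original file):
# a 2-token vocabulary plus blank (id 2). Greedy decoding collapses
# repeats and removes blanks, so [A, A, blank, B] -> "AB".
if __name__ == '__main__':
    vocab = ['A', 'B']
    probs = [[0.6, 0.1, 0.3],   # argmax: A
             [0.6, 0.1, 0.3],   # argmax: A (repeat, collapsed)
             [0.1, 0.1, 0.8],   # argmax: blank (removed)
             [0.1, 0.8, 0.1]]   # argmax: B
    assert ctc_greedy_decoder(probs, vocab) == 'AB'
    beams = ctc_beam_search_decoder(probs, beam_size=4, vocabulary=vocab)
    print(beams[0])  # (log probability, best sentence)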
@ -1,78 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""External Scorer for Beam Search Decoder."""
import os

import kenlm
import numpy as np


class Scorer(object):
    """External scorer to evaluate a prefix or whole sentence in
    beam search decoding, including the score from an n-gram language
    model and a word count.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    """

    def __init__(self, alpha, beta, model_path):
        self._alpha = alpha
        self._beta = beta
        if not os.path.isfile(model_path):
            raise IOError("Invalid language model path: %s" % model_path)
        self._language_model = kenlm.LanguageModel(model_path)

    # n-gram language model scoring
    def _language_model_score(self, sentence):
        # log10 prob of last word
        log_cond_prob = list(
            self._language_model.full_scores(sentence, eos=False))[-1][0]
        return np.power(10, log_cond_prob)

    # word insertion term
    def _word_count(self, sentence):
        words = sentence.strip().split(' ')
        return len(words)

    # reset alpha and beta
    def reset_params(self, alpha, beta):
        self._alpha = alpha
        self._beta = beta

    # execute evaluation
    def __call__(self, sentence, log=False):
        """Evaluation function, gathering all the different scores
        and returning the final one.

        :param sentence: The input sentence for evaluation.
        :type sentence: str
        :param log: Whether to return the score in log representation.
        :type log: bool
        :return: Evaluation score, in decimal or log representation.
        :rtype: float
        """
        lm = self._language_model_score(sentence)
        word_cnt = self._word_count(sentence)
        if log is False:
            score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
        else:
            score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
        return score
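# Worked example (illustrative addition, not part of the original file):
# with alpha=2.0 and beta=1.0, an LM probability of 1e-3 and a 4-word
# sentence give, in the linear domain,
#   score = lm**alpha * word_cnt**beta = (1e-3)**2 * 4 = 4e-6,
# or alpha*log(lm) + beta*log(word_cnt) in the log domain. Running the
# snippet requires a real kenlm binary model, so the path is hypothetical.
#
#     scorer = Scorer(alpha=2.0, beta=1.0, model_path='lm.binary')
#     print(scorer('hello world how are', log=True))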
@ -1,134 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper for various CTC decoders in SWIG."""
import swig_decoders


class Scorer(swig_decoders.Scorer):
    """Wrapper for Scorer.

    :param alpha: Parameter associated with language model. Don't use
                  language model when alpha = 0.
    :type alpha: float
    :param beta: Parameter associated with word count. Don't use word
                 count when beta = 0.
    :type beta: float
    :param model_path: Path to load language model.
    :type model_path: str
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    """

    def __init__(self, alpha, beta, model_path, vocabulary):
        swig_decoders.Scorer.__init__(self, alpha, beta, model_path, vocabulary)


def ctc_greedy_decoder(probs_seq, vocabulary, blank_id):
    """Wrapper for the CTC best path decoder in SWIG.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param blank_id: ID of the blank token.
    :type blank_id: int
    :return: Decoding result string.
    :rtype: str
    """
    result = swig_decoders.ctc_greedy_decoder(probs_seq.tolist(), vocabulary,
                                              blank_id)
    return result


def ctc_beam_search_decoder(probs_seq,
                            vocabulary,
                            beam_size,
                            cutoff_prob=1.0,
                            cutoff_top_n=40,
                            ext_scoring_func=None,
                            blank_id=0):
    """Wrapper for the CTC beam search decoder.

    :param probs_seq: 2-D list of probability distributions over each time
                      step, with each element being a list of normalized
                      probabilities over vocabulary and blank.
    :type probs_seq: 2-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param cutoff_prob: Cutoff probability in pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param blank_id: ID of the blank token.
    :type blank_id: int
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    beam_results = swig_decoders.ctc_beam_search_decoder(
        probs_seq.tolist(), vocabulary, beam_size, cutoff_prob, cutoff_top_n,
        ext_scoring_func, blank_id)
    beam_results = [(res[0], res[1].decode('utf-8')) for res in beam_results]
    return beam_results


def ctc_beam_search_decoder_batch(probs_split,
                                  vocabulary,
                                  beam_size,
                                  num_processes,
                                  cutoff_prob=1.0,
                                  cutoff_top_n=40,
                                  ext_scoring_func=None,
                                  blank_id=0):
    """Wrapper for the batched CTC beam search decoder.

    :param probs_split: 3-D list with each element as an instance of 2-D list
                        of probabilities used by ctc_beam_search_decoder().
    :type probs_split: 3-D list
    :param vocabulary: Vocabulary list.
    :type vocabulary: list
    :param beam_size: Width for beam search.
    :type beam_size: int
    :param num_processes: Number of parallel processes.
    :type num_processes: int
    :param cutoff_prob: Cutoff probability in vocabulary pruning,
                        default 1.0, no pruning.
    :type cutoff_prob: float
    :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                         characters with the highest probabilities in the
                         vocabulary will be used in beam search, default 40.
    :type cutoff_top_n: int
    :param ext_scoring_func: External scoring function for
                             partially decoded sentence, e.g. word count
                             or language model.
    :type ext_scoring_func: callable
    :param blank_id: ID of the blank token.
    :type blank_id: int
    :return: List of tuples of log probability and sentence as decoding
             results, in descending order of the probability.
    :rtype: list
    """
    probs_split = [probs_seq.tolist() for probs_seq in probs_split]

    batch_beam_results = swig_decoders.ctc_beam_search_decoder_batch(
        probs_split, vocabulary, beam_size, num_processes, cutoff_prob,
        cutoff_top_n, ext_scoring_func, blank_id)
    batch_beam_results = [[(res[0], res[1]) for res in beam_results]
                          for beam_results in batch_beam_results]
    return batch_beam_results
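# Usage sketch (illustrative addition, not part of the original file):
# decoding a batch of two utterances, where 'probs_a'/'probs_b' are assumed
# to be numpy softmax outputs of shape [num_time_steps, vocab_size + 1] and
# 'vocab_list' a matching vocabulary.
#
#     beams = ctc_beam_search_decoder_batch(
#         probs_split=[probs_a, probs_b],
#         vocabulary=vocab_list,
#         beam_size=500,
#         num_processes=8,
#         cutoff_prob=1.0,
#         cutoff_top_n=40)
#     best_transcripts = [b[0][1] for b in beams]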
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,721 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio segment class."""
import copy
import io
import random
import re
import struct

import numpy as np
import resampy
import soundfile
import soxbindings as sox
from scipy import signal


class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples [num_samples x num_channels].
    :type samples: ndarray.float32
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises TypeError: If the sample data type is not float or int.
    """

    def __init__(self, samples, sample_rate):
        """Create audio segment from samples.

        Samples are converted to float32 internally, with ints scaled to
        [-1, 1].
        """
        self._samples = self._convert_samples_to_float32(samples)
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            self._samples = np.mean(self._samples, 1)

    def __eq__(self, other):
        """Return whether two objects are equal."""
        if type(other) is not type(self):
            return False
        if self._sample_rate != other._sample_rate:
            return False
        if self._samples.shape != other._samples.shape:
            return False
        if np.any(self.samples != other._samples):
            return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    def __str__(self):
        """Return a human-readable representation of the segment."""
        return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
                "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate,
                                self.duration, self.rms_db))

    @classmethod
    def from_file(cls, file):
        """Create audio segment from audio file.

        :param file: Filepath or file object to audio file.
        :type file: str|file
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        if isinstance(file, str) and re.findall(r".seqbin_\d+$", file):
            return cls.from_sequence_file(file)
        else:
            samples, sample_rate = soundfile.read(file, dtype='float32')
            return cls(samples, sample_rate)

    @classmethod
    def slice_from_file(cls, file, start=None, end=None):
        """Load a small section of an audio file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param file: Input audio filepath or file object.
        :type file: str|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :return: AudioSegment instance of the specified slice of the input
                 audio file.
        :rtype: AudioSegment
        :raise ValueError: If start or end is incorrectly set, e.g. out of
                           bounds in time.
        """
        sndfile = soundfile.SoundFile(file)
        sample_rate = sndfile.samplerate
        duration = float(len(sndfile)) / sample_rate
        start = 0. if start is None else start
        end = duration if end is None else end
        if start < 0.0:
            start += duration
        if end < 0.0:
            end += duration
        if start < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start)
        if end < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end)
        if start > end:
            raise ValueError("The slice start position (%f s) is later than "
                             "the slice end position (%f s)." % (start, end))
        if end > duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)" % (end, duration))
        start_frame = int(start * sample_rate)
        end_frame = int(end * sample_rate)
        sndfile.seek(start_frame)
        data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
        return cls(data, sample_rate)

    @classmethod
    def from_sequence_file(cls, filepath):
        """Create audio segment from a sequence file. A sequence file is a
        binary file containing a collection of multiple audio files, with
        several header bytes at the head indicating the offsets of each
        audio byte data chunk.

        The format is:

            4 bytes (int, version),
            4 bytes (int, num of utterance),
            4 bytes (int, bytes per header),
            [bytes_per_header*(num_utterance+1)] bytes (offsets for each audio),
            audio_bytes_data_of_1st_utterance,
            audio_bytes_data_of_2nd_utterance,
            ......

        A sequence file name must end with ".seqbin", and the filename of the
        5th utterance's audio file in sequence file "xxx.seqbin" must be
        "xxx.seqbin_5", with "5" indicating the utterance index within this
        sequence file (starting from 1).

        :param filepath: Filepath of sequence file.
        :type filepath: str
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        # parse filepath
        matches = re.match(r"(.+\.seqbin)_(\d+)", filepath)
        if matches is None:
            raise IOError("File type of %s is not supported" % filepath)
        filename = matches.group(1)
        fileno = int(matches.group(2))

        # read headers (binary mode, so no text encoding is involved)
        f = io.open(filename, mode='rb')
        version = f.read(4)
        num_utterances = struct.unpack("i", f.read(4))[0]
        bytes_per_header = struct.unpack("i", f.read(4))[0]
        header_bytes = f.read(bytes_per_header * (num_utterances + 1))
        header = [
            struct.unpack("i", header_bytes[bytes_per_header * i:
                                            bytes_per_header * (i + 1)])[0]
            for i in range(num_utterances + 1)
        ]

        # read audio bytes
        f.seek(header[fileno - 1])
        audio_bytes = f.read(header[fileno] - header[fileno - 1])
        f.close()

        # create audio segment
        try:
            return cls.from_bytes(audio_bytes)
        except Exception as e:
            samples = np.frombuffer(audio_bytes, dtype='int16')
            return cls(samples=samples, sample_rate=8000)

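    # Usage sketch (illustrative addition, not part of the original file):
    # "audio.seqbin" is a hypothetical sequence file; the "_5" suffix
    # selects the 5th utterance (1-based), which from_file() routes to
    # from_sequence_file().
    #
    #     seg = AudioSegment.from_file('audio.seqbin_5')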
    @classmethod
    def from_bytes(cls, bytes):
        """Create audio segment from a byte string containing audio samples.

        :param bytes: Byte string containing audio samples.
        :type bytes: str
        :return: Audio segment instance.
        :rtype: AudioSegment
        """
        samples, sample_rate = soundfile.read(
            io.BytesIO(bytes), dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of audio segments together.

        :param *segments: Input audio segments to be concatenated.
        :type *segments: tuple of AudioSegment
        :return: Audio segment instance as concatenating results.
        :rtype: AudioSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any two segments does not match.
        :raises TypeError: If any segment is not an AudioSegment instance.
        """
        # Perform basic sanity-checks.
        if len(segments) == 0:
            raise ValueError("No audio segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only audio segments of the same type "
                                "can be concatenated.")
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Create a silent audio segment of the given duration and sample rate.

        :param duration: Length of silence in seconds.
        :type duration: float
        :param sample_rate: Sample rate.
        :type sample_rate: float
        :return: Silent AudioSegment instance of the given duration.
        :rtype: AudioSegment
        """
        samples = np.zeros(int(duration * sample_rate))
        return cls(samples, sample_rate)

    def to_wav_file(self, filepath, dtype='float32'):
        """Save audio segment to disk as wav file.

        :param filepath: WAV filepath or file object to save the
                         audio segment.
        :type filepath: str|file
        :param dtype: Subtype for audio file. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :raises TypeError: If dtype is not supported.
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        subtype_map = {
            'int16': 'PCM_16',
            'int32': 'PCM_32',
            'float32': 'FLOAT',
            'float64': 'DOUBLE'
        }
        soundfile.write(
            filepath,
            samples,
            self._sample_rate,
            format='WAV',
            subtype=subtype_map[dtype])

    def superimpose(self, other):
        """Add samples from another segment to those of this segment
        (sample-wise addition, not segment concatenation).

        Note that this is an in-place transformation.

        :param other: Segment containing samples to be added in.
        :type other: AudioSegment
        :raise TypeError: If the types of the two segments don't match.
        :raise ValueError: If the sample rates of the two segments are not
                           equal, or if the lengths of the segments don't
                           match.
        """
        # The type check must raise only when the types do NOT match.
        if not isinstance(other, type(self)):
            raise TypeError("Cannot add segments of different types: %s "
                            "and %s." % (type(self), type(other)))
        if self._sample_rate != other._sample_rate:
            raise ValueError("Sample rates must match to add segments.")
        if len(self._samples) != len(other._samples):
            raise ValueError("Segment lengths must match to add segments.")
        self._samples += other._samples

    def to_bytes(self, dtype='float32'):
        """Create a byte string containing the audio content.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'float32'.
        :type dtype: str
        :return: Byte string containing audio content.
        :rtype: str
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples.tobytes()

    def to(self, dtype='int16'):
        """Create a `dtype` audio content.

        :param dtype: Data type for export samples. Options: 'int16', 'int32',
                      'float32', 'float64'. Default is 'int16'.
        :type dtype: str
        :return: np.ndarray containing `dtype` audio content.
        :rtype: ndarray
        """
        samples = self._convert_samples_from_float32(self._samples, dtype)
        return samples

    def gain_db(self, gain):
        """Apply gain in decibels to samples.

        Note that this is an in-place transformation.

        :param gain: Gain in decibels to apply to samples.
        :type gain: float|1darray
        """
        self._samples *= 10.**(gain / 20.)

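    # Worked example (illustrative addition, not part of the original file):
    # gain is converted from decibels to a linear amplitude factor via
    # 10 ** (gain / 20), so gain_db(6.0) scales samples by
    # 10 ** (6 / 20) ~= 1.995, roughly doubling the amplitude, while
    # gain_db(-6.0) roughly halves it.
    #
    #     seg.gain_db(6.0)   # samples *= 10 ** (6 / 20) ~= 1.995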
def change_speed(self, speed_rate):
|
|
||||||
"""Change the audio speed by linear interpolation.
|
|
||||||
|
|
||||||
Note that this is an in-place transformation.
|
|
||||||
|
|
||||||
:param speed_rate: Rate of speed change:
|
|
||||||
speed_rate > 1.0, speed up the audio;
|
|
||||||
speed_rate = 1.0, unchanged;
|
|
||||||
speed_rate < 1.0, slow down the audio;
|
|
||||||
speed_rate <= 0.0, not allowed, raise ValueError.
|
|
||||||
:type speed_rate: float
|
|
||||||
:raises ValueError: If speed_rate <= 0.0.
|
|
||||||
"""
|
|
||||||
if speed_rate == 1.0:
|
|
||||||
return
|
|
||||||
if speed_rate <= 0:
|
|
||||||
raise ValueError("speed_rate should be greater than zero.")
|
|
||||||
|
|
||||||
# numpy
|
|
||||||
# old_length = self._samples.shape[0]
|
|
||||||
# new_length = int(old_length / speed_rate)
|
|
||||||
# old_indices = np.arange(old_length)
|
|
||||||
# new_indices = np.linspace(start=0, stop=old_length, num=new_length)
|
|
||||||
# self._samples = np.interp(new_indices, old_indices, self._samples)
|
|
||||||
|
|
||||||
# sox, slow
|
|
||||||
tfm = sox.Transformer()
|
|
||||||
tfm.set_globals(multithread=False)
|
|
||||||
tfm.speed(speed_rate)
|
|
||||||
self._samples = tfm.build_array(
|
|
||||||
input_array=self._samples,
|
|
||||||
sample_rate_in=self._sample_rate).squeeze(-1).astype(
|
|
||||||
np.float32).copy()
|
|
||||||
|
|
||||||
    def normalize(self, target_db=-20, max_gain_db=300.0):
        """Normalize audio to be of the desired RMS value in decibels.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels. This value should be
                          less than 0.0 as 0.0 is full-scale audio.
        :type target_db: float
        :param max_gain_db: Max amount of gain in dB that can be applied for
                            normalization. This is to prevent nans when
                            attempting to normalize a signal consisting of
                            all zeros.
        :type max_gain_db: float
        :raises ValueError: If the required gain to normalize the segment to
                            the target_db value exceeds max_gain_db.
        """
        gain = target_db - self.rms_db
        if gain > max_gain_db:
            raise ValueError(
                "Unable to normalize segment to %f dB because the required "
                "gain exceeds max_gain_db (%f dB)." %
                (target_db, max_gain_db))
        self.gain_db(gain)
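
    # Worked check (hypothetical): for a segment with rms_db == -26.0,
    # normalize(target_db=-20) applies a gain of -20 - (-26) = +6 dB. An
    # all-zero segment has rms_db == -inf, so the required gain would be
    # infinite and the max_gain_db guard raises ValueError instead.
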
    def normalize_online_bayesian(self,
                                  target_db,
                                  prior_db,
                                  prior_samples,
                                  startup_delay=0.0):
        """Normalize audio using a production-compatible online/causal
        algorithm. This uses an exponential likelihood and gamma prior to
        make online estimates of the RMS even when there are very few samples.

        Note that this is an in-place transformation.

        :param target_db: Target RMS value in decibels.
        :type target_db: float
        :param prior_db: Prior RMS estimate in decibels.
        :type prior_db: float
        :param prior_samples: Prior strength in number of samples.
        :type prior_samples: float
        :param startup_delay: Default 0.0s. If provided, this function will
                              accrue statistics for the first startup_delay
                              seconds before applying online normalization.
        :type startup_delay: float
        """
        # Estimate total RMS online.
        startup_sample_idx = min(self.num_samples - 1,
                                 int(self.sample_rate * startup_delay))
        prior_mean_squared = 10.**(prior_db / 10.)
        prior_sum_of_squares = prior_mean_squared * prior_samples
        cumsum_of_squares = np.cumsum(self.samples**2)
        sample_count = np.arange(self.num_samples) + 1
        if startup_sample_idx > 0:
            cumsum_of_squares[:startup_sample_idx] = \
                cumsum_of_squares[startup_sample_idx]
            sample_count[:startup_sample_idx] = \
                sample_count[startup_sample_idx]
        mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) /
                                 (sample_count + prior_samples))
        rms_estimate_db = 10 * np.log10(mean_squared_estimate)
        # Compute required time-varying gain.
        gain_db = target_db - rms_estimate_db
        self.gain_db(gain_db)
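
    # The estimate above is the posterior mean of the signal power under a
    # conjugate prior: at sample n,
    #   ms[n] = (sum_{i<=n} x_i**2 + prior_ms * prior_samples)
    #           / (n + prior_samples)
    # so the first frames are dominated by prior_db and later frames by the
    # observed data.
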
    def resample(self, target_sample_rate, filter='kaiser_best'):
        """Resample the audio to a target sample rate.

        Note that this is an in-place transformation.

        :param target_sample_rate: Target sample rate.
        :type target_sample_rate: int
        :param filter: The resampling filter to use, one of {'kaiser_best',
                       'kaiser_fast'}.
        :type filter: str
        """
        self._samples = resampy.resample(
            self.samples, self.sample_rate, target_sample_rate, filter=filter)
        self._sample_rate = target_sample_rate
    def pad_silence(self, duration, sides='both'):
        """Pad this audio sample with a period of silence.

        Note that this is an in-place transformation.

        :param duration: Length of silence in seconds to pad.
        :type duration: float
        :param sides: Position for padding:
                      'beginning' - adds silence in the beginning;
                      'end' - adds silence in the end;
                      'both' - adds silence in both the beginning and the end.
        :type sides: str
        :raises ValueError: If sides is not supported.
        """
        if duration == 0.0:
            return self
        cls = type(self)
        silence = self.make_silence(duration, self._sample_rate)
        if sides == "beginning":
            padded = cls.concatenate(silence, self)
        elif sides == "end":
            padded = cls.concatenate(self, silence)
        elif sides == "both":
            padded = cls.concatenate(silence, self, silence)
        else:
            raise ValueError("Unknown value for sides: %s" % sides)
        self._samples = padded._samples
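
    # Hypothetical sketch: padding 0.5 s on both sides grows the duration by
    # 1.0 s in total.
    #
    #   before = seg.duration
    #   seg.pad_silence(0.5, sides='both')
    #   assert abs(seg.duration - (before + 1.0)) < 1e-6
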
    def shift(self, shift_ms):
        """Shift the audio in time. If `shift_ms` is positive, shift with time
        advance; if negative, shift with time delay. Silence is padded to
        keep the duration unchanged.

        Note that this is an in-place transformation.

        :param shift_ms: Shift time in milliseconds. If positive, shift with
                         time advance; if negative, shift with time delay.
        :type shift_ms: float
        :raises ValueError: If shift_ms is longer than audio duration.
        """
        if abs(shift_ms) / 1000.0 > self.duration:
            raise ValueError("Absolute value of shift_ms should be smaller "
                             "than audio duration.")
        shift_samples = int(shift_ms * self._sample_rate / 1000)
        if shift_samples > 0:
            # time advance
            self._samples[:-shift_samples] = self._samples[shift_samples:]
            self._samples[-shift_samples:] = 0
        elif shift_samples < 0:
            # time delay
            self._samples[-shift_samples:] = self._samples[:shift_samples]
            self._samples[:-shift_samples] = 0
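
    # Worked check (hypothetical): at 16 kHz, shift(5.0) advances the audio
    # by 5 ms = 80 samples; samples[:-80] take the old samples[80:] and the
    # last 80 samples become silence.
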
    def subsegment(self, start_sec=None, end_sec=None):
        """Cut the AudioSegment between given boundaries.

        Note that this is an in-place transformation.

        :param start_sec: Beginning of subsegment in seconds.
        :type start_sec: float
        :param end_sec: End of subsegment in seconds.
        :type end_sec: float
        :raises ValueError: If start_sec or end_sec is incorrectly set,
                            e.g. out of bounds in time.
        """
        start_sec = 0.0 if start_sec is None else start_sec
        end_sec = self.duration if end_sec is None else end_sec
        if start_sec < 0.0:
            start_sec = self.duration + start_sec
        if end_sec < 0.0:
            end_sec = self.duration + end_sec
        if start_sec < 0.0:
            raise ValueError("The slice start position (%f s) is out of "
                             "bounds." % start_sec)
        if end_sec < 0.0:
            raise ValueError("The slice end position (%f s) is out of bounds." %
                             end_sec)
        if start_sec > end_sec:
            raise ValueError("The slice start position (%f s) is later than "
                             "the end position (%f s)." % (start_sec, end_sec))
        if end_sec > self.duration:
            raise ValueError("The slice end position (%f s) is out of bounds "
                             "(> %f s)." % (end_sec, self.duration))
        start_sample = int(round(start_sec * self._sample_rate))
        end_sample = int(round(end_sec * self._sample_rate))
        self._samples = self._samples[start_sample:end_sample]
    def random_subsegment(self, subsegment_length, rng=None):
        """Randomly cut a subsegment of the specified length out of the
        audio segment.

        Note that this is an in-place transformation.

        :param subsegment_length: Subsegment length in seconds.
        :type subsegment_length: float
        :param rng: Random number generator state.
        :type rng: random.Random
        :raises ValueError: If the length of the subsegment is greater than
                            the original segment.
        """
        rng = random.Random() if rng is None else rng
        if subsegment_length > self.duration:
            raise ValueError("Length of subsegment must not be greater "
                             "than original segment.")
        start_time = rng.uniform(0.0, self.duration - subsegment_length)
        self.subsegment(start_time, start_time + subsegment_length)
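
    # Hypothetical sketch: with a seeded generator the crop is reproducible.
    #
    #   import random
    #   seg.random_subsegment(1.0, rng=random.Random(0))
    #   assert abs(seg.duration - 1.0) <= 1.0 / seg.sample_rate
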
    def convolve(self, impulse_segment, allow_resample=False):
        """Convolve this audio segment with the given impulse segment.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        :raises ValueError: If the sample rates of the two audio segments do
                            not match and resampling is not allowed.
        """
        if allow_resample and self.sample_rate != impulse_segment.sample_rate:
            impulse_segment.resample(self.sample_rate)
        if self.sample_rate != impulse_segment.sample_rate:
            raise ValueError("Impulse segment's sample rate (%d Hz) is not "
                             "equal to base signal sample rate (%d Hz)." %
                             (impulse_segment.sample_rate, self.sample_rate))
        samples = signal.fftconvolve(self.samples, impulse_segment.samples,
                                     "full")
        self._samples = samples
    def convolve_and_normalize(self, impulse_segment, allow_resample=False):
        """Convolve and normalize the resulting audio segment so that it
        has the same average power as the input signal.

        Note that this is an in-place transformation.

        :param impulse_segment: Impulse response segment.
        :type impulse_segment: AudioSegment
        :param allow_resample: Indicates whether resampling is allowed when
                               the impulse_segment has a different sample
                               rate from this signal.
        :type allow_resample: bool
        """
        target_db = self.rms_db
        self.convolve(impulse_segment, allow_resample=allow_resample)
        self.normalize(target_db)
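
    # Hypothetical usage sketch: reverberate with a room impulse response
    # while keeping the perceived level roughly constant (the file path is
    # illustrative only).
    #
    #   rir = AudioSegment.from_file('rir.wav')
    #   seg.convolve_and_normalize(rir, allow_resample=True)
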
    def add_noise(self,
                  noise,
                  snr_dB,
                  allow_downsampling=False,
                  max_gain_db=300.0,
                  rng=None):
        """Add the given noise segment at a specific signal-to-noise ratio.
        If the noise segment is longer than this segment, a random subsegment
        of matching length is sampled from it and used instead.

        Note that this is an in-place transformation.

        :param noise: Noise signal to add.
        :type noise: AudioSegment
        :param snr_dB: Signal-to-Noise Ratio, in decibels.
        :type snr_dB: float
        :param allow_downsampling: Whether to allow the noise signal to be
                                   downsampled to match the base signal sample
                                   rate.
        :type allow_downsampling: bool
        :param max_gain_db: Maximum amount of gain to apply to noise signal
                            before adding it in. This is to prevent attempting
                            to apply infinite gain to a zero signal.
        :type max_gain_db: float
        :param rng: Random number generator state.
        :type rng: None|random.Random
        :raises ValueError: If the sample rates do not match between the two
                            audio segments when downsampling is not allowed,
                            or if the noise segment is shorter than this
                            segment.
        """
        rng = random.Random() if rng is None else rng
        if allow_downsampling and noise.sample_rate > self.sample_rate:
            # resample() works in place and returns None, so copy the noise
            # first instead of assigning its return value.
            noise = copy.deepcopy(noise)
            noise.resample(self.sample_rate)
        if noise.sample_rate != self.sample_rate:
            raise ValueError("Noise sample rate (%d Hz) is not equal to base "
                             "signal sample rate (%d Hz)." % (noise.sample_rate,
                                                              self.sample_rate))
        if noise.duration < self.duration:
            raise ValueError("Noise signal (%f sec) must be at least as long as"
                             " base signal (%f sec)." %
                             (noise.duration, self.duration))
        noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
        noise_new = copy.deepcopy(noise)
        noise_new.random_subsegment(self.duration, rng=rng)
        noise_new.gain_db(noise_gain_db)
        self.superimpose(noise_new)
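
    # Hypothetical usage sketch: mix in background noise at a fixed SNR, as
    # NoisePerturbAugmentor below does with a randomly sampled SNR (the file
    # path is illustrative only).
    #
    #   noise = AudioSegment.from_file('noise.wav')
    #   seg.add_noise(noise, snr_dB=15.0, allow_downsampling=True)
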
    @property
    def samples(self):
        """Return audio samples.

        :return: Audio samples.
        :rtype: ndarray
        """
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return audio sample rate.

        :return: Audio sample rate.
        :rtype: int
        """
        return self._sample_rate

    @property
    def num_samples(self):
        """Return number of samples.

        :return: Number of samples.
        :rtype: int
        """
        return self._samples.shape[0]

    @property
    def duration(self):
        """Return audio duration.

        :return: Audio duration in seconds.
        :rtype: float
        """
        return self._samples.shape[0] / float(self._sample_rate)

    @property
    def rms_db(self):
        """Return root mean square energy of the audio in decibels.

        :return: Root mean square energy in decibels.
        :rtype: float
        """
        # mean_square is already a power quantity, so convert with
        # 10 * log10 instead of 20 * log10.
        mean_square = np.mean(self._samples**2)
        return 10 * np.log10(mean_square)
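
    # Worked check (hypothetical): a full-scale square wave has samples of
    # +/-1.0, so mean_square == 1.0 and rms_db == 10 * log10(1.0) == 0.0 dB;
    # halving the amplitude gives mean_square == 0.25 and rms_db ~= -6.02 dB.
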
    def _convert_samples_to_float32(self, samples):
        """Convert sample type to float32.

        Audio sample type is usually integer or floating-point.
        Integers will be scaled to [-1, 1] in float32.
        """
        float32_samples = samples.astype('float32')
        if samples.dtype in np.sctypes['int']:
            bits = np.iinfo(samples.dtype).bits
            float32_samples *= (1. / 2**(bits - 1))
        elif samples.dtype in np.sctypes['float']:
            pass
        else:
            raise TypeError("Unsupported sample type: %s." % samples.dtype)
        return float32_samples
    def _convert_samples_from_float32(self, samples, dtype):
        """Convert sample type from float32 to dtype.

        Audio sample type is usually integer or floating-point. For integer
        types, float32 will be rescaled from [-1, 1] to the maximum range
        supported by the integer type.

        This is used when writing an audio file.
        """
        dtype = np.dtype(dtype)
        output_samples = samples.copy()
        if dtype in np.sctypes['int']:
            bits = np.iinfo(dtype).bits
            output_samples *= (2**(bits - 1) / 1.)
            min_val = np.iinfo(dtype).min
            max_val = np.iinfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        elif dtype in np.sctypes['float']:
            min_val = np.finfo(dtype).min
            max_val = np.finfo(dtype).max
            output_samples[output_samples > max_val] = max_val
            output_samples[output_samples < min_val] = min_val
        else:
            raise TypeError("Unsupported sample type: %s." % dtype)
        return output_samples.astype(dtype)
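
    # Hypothetical check of the rescaling: a float32 value of 0.5 exported as
    # 'int16' becomes int(0.5 * 2**15) == 16384, and values outside [-1, 1]
    # are clipped to the int16 range [-32768, 32767].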
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,218 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the data augmentation pipeline."""
import json
from collections.abc import Sequence
from inspect import signature

import numpy as np

from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.log import Log

__all__ = ["AugmentationPipeline"]

logger = Log(__name__).getlog()

import_alias = dict(
    volume="deepspeech.frontend.augmentor.volume_perturb:VolumePerturbAugmentor",
    shift="deepspeech.frontend.augmentor.shift_perturb:ShiftPerturbAugmentor",
    speed="deepspeech.frontend.augmentor.speed_perturb:SpeedPerturbAugmentor",
    resample="deepspeech.frontend.augmentor.resample:ResampleAugmentor",
    bayesian_normal="deepspeech.frontend.augmentor.online_bayesian_normalization:OnlineBayesianNormalizationAugmentor",
    noise="deepspeech.frontend.augmentor.noise_perturb:NoisePerturbAugmentor",
    impulse="deepspeech.frontend.augmentor.impulse_response:ImpulseResponseAugmentor",
    specaug="deepspeech.frontend.augmentor.spec_augment:SpecAugmentor", )

class AugmentationPipeline():
    """Build a pre-processing pipeline with various augmentation models. Such a
    data augmentation pipeline is often leveraged to augment the training
    samples to make the model invariant to certain types of perturbations in
    the real world, improving the model's generalization ability.

    The pipeline is built according to the augmentation configuration in a
    JSON string, e.g.

    .. code-block::

        [ {
                "type": "noise",
                "params": {"min_snr_dB": 10,
                           "max_snr_dB": 20,
                           "noise_manifest_path": "datasets/manifest.noise"},
                "prob": 0.0
            },
            {
                "type": "speed",
                "params": {"min_speed_rate": 0.9,
                           "max_speed_rate": 1.1},
                "prob": 1.0
            },
            {
                "type": "shift",
                "params": {"min_shift_ms": -5,
                           "max_shift_ms": 5},
                "prob": 1.0
            },
            {
                "type": "volume",
                "params": {"min_gain_dBFS": -10,
                           "max_gain_dBFS": 10},
                "prob": 0.0
            },
            {
                "type": "bayesian_normal",
                "params": {"target_db": -20,
                           "prior_db": -20,
                           "prior_samples": 100},
                "prob": 0.0
            }
        ]

    This example configuration registers five augmentation models. "prob"
    indicates the probability that an augmentor takes effect for a given
    sample; if "prob" is zero, the augmentor never takes effect, so here only
    the SpeedPerturbAugmentor and ShiftPerturbAugmentor are actually applied.

    Params:
        augmentation_config(str): Augmentation configuration in JSON string.
        random_seed(int): Random seed.
        train(bool): whether in training mode.

    Raises:
        ValueError: If the augmentation JSON config is in an incorrect format.
    """

    SPEC_TYPES = {'specaug'}

    def __init__(self, augmentation_config: str, random_seed: int=0):
        self._rng = np.random.RandomState(random_seed)
        self.conf = {'mode': 'sequential', 'process': []}
        if augmentation_config:
            process = json.loads(augmentation_config)
            self.conf['process'] += process

        self._augmentors, self._rates = self._parse_pipeline_from('all')
        self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
            'audio')
        self._spec_augmentors, self._spec_rates = self._parse_pipeline_from(
            'feature')
    def __call__(self, xs, uttid_list=None, **kwargs):
        if not isinstance(xs, Sequence):
            is_batch = False
            xs = [xs]
        else:
            is_batch = True

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list for _ in range(len(xs))]

        if self.conf.get("mode", "sequential") == "sequential":
            for idx, (func, rate) in enumerate(
                    zip(self._augmentors, self._rates), 0):
                if self._rng.uniform(0., 1.) >= rate:
                    continue

                # Pass only the kwargs that the func accepts.
                try:
                    param = signature(func).parameters
                except ValueError:
                    # Some functions, e.g. built-in functions, do not
                    # support signature inspection.
                    param = {}
                _kwargs = {k: v for k, v in kwargs.items() if k in param}

                try:
                    if uttid_list is not None and "uttid" in param:
                        xs = [
                            func(x, u, **_kwargs)
                            for x, u in zip(xs, uttid_list)
                        ]
                    else:
                        xs = [func(x, **_kwargs) for x in xs]
                except Exception:
                    logger.fatal("Caught an exception from the {}th func: {}".
                                 format(idx, func))
                    raise
        else:
            raise NotImplementedError(
                "Not supporting mode={}".format(self.conf["mode"]))

        if is_batch:
            return xs
        else:
            return xs[0]
    def transform_audio(self, audio_segment):
        """Run the pre-processing pipeline for data augmentation.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to process.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        for augmentor, rate in zip(self._audio_augmentors, self._audio_rates):
            if self._rng.uniform(0., 1.) < rate:
                augmentor.transform_audio(audio_segment)

    def transform_feature(self, spec_segment):
        """Spectrogram augmentation.

        Args:
            spec_segment (np.ndarray): audio feature, (D, T).
        """
        for augmentor, rate in zip(self._spec_augmentors, self._spec_rates):
            if self._rng.uniform(0., 1.) < rate:
                spec_segment = augmentor.transform_feature(spec_segment)
        return spec_segment

    def _parse_pipeline_from(self, aug_type='all'):
        """Parse the config json to build an augmentation pipeline."""
        assert aug_type in ('audio', 'feature', 'all'), aug_type
        audio_confs = []
        feature_confs = []
        all_confs = []
        for config in self.conf['process']:
            all_confs.append(config)
            if config["type"] in self.SPEC_TYPES:
                feature_confs.append(config)
            else:
                audio_confs.append(config)

        if aug_type == 'audio':
            aug_confs = audio_confs
        elif aug_type == 'feature':
            aug_confs = feature_confs
        else:
            aug_confs = all_confs

        augmentors = [
            self._get_augmentor(config["type"], config["params"])
            for config in aug_confs
        ]
        rates = [config["prob"] for config in aug_confs]
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Return an augmentation model by the type name, and pass in params."""
        class_obj = dynamic_import(augmentor_type, import_alias)
        assert issubclass(class_obj, AugmentorBase)
        try:
            obj = class_obj(self._rng, **params)
        except Exception:
            raise ValueError("Invalid config for augmentor type [%s]." %
                             augmentor_type)
        return obj
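
# Hypothetical usage sketch (the config values are illustrative only):
#
#   config = ('[{"type": "shift", '
#             '"params": {"min_shift_ms": -5, "max_shift_ms": 5}, '
#             '"prob": 1.0}]')
#   pipeline = AugmentationPipeline(augmentation_config=config, random_seed=0)
#   pipeline.transform_audio(audio_segment)    # audio-level augmentors
#   feat = pipeline.transform_feature(feat)    # feature-level augmentors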
@ -1,59 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the abstract base class for augmentation models."""
from abc import ABCMeta
from abc import abstractmethod


class AugmentorBase(metaclass=ABCMeta):
    """Abstract base class for augmentation model (augmentor) classes.
    All augmentor classes should inherit from this class, and implement the
    following abstract methods.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __call__(self, xs):
        raise NotImplementedError("AugmentorBase: Not impl __call__")

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Adds various effects to the input audio segment. Such effects
        will augment the training data to make the model invariant to certain
        types of perturbations in the real world, improving the model's
        generalization ability.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        raise NotImplementedError("AugmentorBase: Not impl transform_audio")

    @abstractmethod
    def transform_feature(self, spec_segment):
        """Adds various effects to the input audio feature segment. Such
        effects will augment the training data to make the model invariant to
        certain types of time_mask or freq_mask in the real world, improving
        the model's generalization ability.

        Args:
            spec_segment (Spectrogram): Spectrogram segment to add effects to.
        """
        raise NotImplementedError("AugmentorBase: Not impl transform_feature")
@ -1,50 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the impulse response augmentation model."""
from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.frontend.utility import read_manifest


class ImpulseResponseAugmentor(AugmentorBase):
    """Augmentation model for adding impulse response effect.

    :param rng: Random generator object.
    :type rng: np.random.RandomState
    :param impulse_manifest_path: Manifest path for impulse audio data.
    :type impulse_manifest_path: str
    """

    def __init__(self, rng, impulse_manifest_path):
        self._rng = rng
        self._impulse_manifest = read_manifest(impulse_manifest_path)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Add impulse response effect.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        impulse_json = self._rng.choice(
            self._impulse_manifest, 1, replace=False)[0]
        impulse_segment = AudioSegment.from_file(impulse_json['audio_filepath'])
        audio_segment.convolve(impulse_segment, allow_resample=True)
@ -1,64 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the noise perturb augmentation model."""
from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.frontend.utility import read_manifest


class NoisePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding background noise.

    :param rng: Random generator object.
    :type rng: np.random.RandomState
    :param min_snr_dB: Minimal signal-to-noise ratio, in decibels.
    :type min_snr_dB: float
    :param max_snr_dB: Maximal signal-to-noise ratio, in decibels.
    :type max_snr_dB: float
    :param noise_manifest_path: Manifest path for noise audio data.
    :type noise_manifest_path: str
    """

    def __init__(self, rng, min_snr_dB, max_snr_dB, noise_manifest_path):
        self._min_snr_dB = min_snr_dB
        self._max_snr_dB = max_snr_dB
        self._rng = rng
        self._noise_manifest = read_manifest(manifest_path=noise_manifest_path)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Add background noise audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        noise_json = self._rng.choice(self._noise_manifest, 1, replace=False)[0]
        if noise_json['duration'] < audio_segment.duration:
            raise RuntimeError("The duration of the sampled noise audio is "
                               "shorter than that of the audio segment to "
                               "add effects to.")
        diff_duration = noise_json['duration'] - audio_segment.duration
        start = self._rng.uniform(0, diff_duration)
        end = start + audio_segment.duration
        noise_segment = AudioSegment.slice_from_file(
            noise_json['audio_filepath'], start=start, end=end)
        snr_dB = self._rng.uniform(self._min_snr_dB, self._max_snr_dB)
        audio_segment.add_noise(
            noise_segment, snr_dB, allow_downsampling=True, rng=self._rng)
@ -1,63 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the online Bayesian normalization augmentation model."""
from deepspeech.frontend.augmentor.base import AugmentorBase


class OnlineBayesianNormalizationAugmentor(AugmentorBase):
    """Augmentation model for adding online Bayesian normalization.

    :param rng: Random generator object.
    :type rng: random.Random
    :param target_db: Target RMS value in decibels.
    :type target_db: float
    :param prior_db: Prior RMS estimate in decibels.
    :type prior_db: float
    :param prior_samples: Prior strength in number of samples.
    :type prior_samples: int
    :param startup_delay: Default 0.0s. If provided, this function will
                          accrue statistics for the first startup_delay
                          seconds before applying online normalization.
    :type startup_delay: float
    """

    def __init__(self,
                 rng,
                 target_db,
                 prior_db,
                 prior_samples,
                 startup_delay=0.0):
        self._target_db = target_db
        self._prior_db = prior_db
        self._prior_samples = prior_samples
        self._rng = rng
        self._startup_delay = startup_delay

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Normalizes the input audio using the online Bayesian approach.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.normalize_online_bayesian(self._target_db, self._prior_db,
                                                self._prior_samples,
                                                self._startup_delay)
@ -1,48 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the resample augmentation model."""
from deepspeech.frontend.augmentor.base import AugmentorBase


class ResampleAugmentor(AugmentorBase):
    """Augmentation model for resampling.

    See more info here:
    https://ccrma.stanford.edu/~jos/resample/index.html

    :param rng: Random generator object.
    :type rng: random.Random
    :param new_sample_rate: New sample rate in Hz.
    :type new_sample_rate: int
    """

    def __init__(self, rng, new_sample_rate):
        self._new_sample_rate = new_sample_rate
        self._rng = rng

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Resamples the input audio to a target sample rate.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        audio_segment.resample(self._new_sample_rate)
@ -1,49 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the shift perturb augmentation model."""
from deepspeech.frontend.augmentor.base import AugmentorBase


class ShiftPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random shift perturbation.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_shift_ms: Minimal shift in milliseconds.
    :type min_shift_ms: float
    :param max_shift_ms: Maximal shift in milliseconds.
    :type max_shift_ms: float
    """

    def __init__(self, rng, min_shift_ms, max_shift_ms):
        self._min_shift_ms = min_shift_ms
        self._max_shift_ms = max_shift_ms
        self._rng = rng

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Shift audio.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms)
        audio_segment.shift(shift_ms)
@ -1,256 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the SpecAugment augmentation model."""
import random

import numpy as np
from PIL import Image
from PIL.Image import BICUBIC

from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class SpecAugmentor(AugmentorBase):
    """Augmentation model for time warping, frequency masking and time masking.

    SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
    https://arxiv.org/abs/1904.08779

    SpecAugment on Large Scale Datasets
    https://arxiv.org/abs/1912.05533
    """

    def __init__(self,
                 rng,
                 F,
                 T,
                 n_freq_masks,
                 n_time_masks,
                 p=1.0,
                 W=40,
                 adaptive_number_ratio=0,
                 adaptive_size_ratio=0,
                 max_n_time_masks=20,
                 replace_with_zero=True,
                 warp_mode='PIL'):
        """SpecAugment class.
        Args:
            rng (random.Random): random generator object.
            F (int): parameter for frequency masking
            T (int): parameter for time masking
            n_freq_masks (int): number of frequency masks
            n_time_masks (int): number of time masks
            p (float): parameter for upperbound of the time mask
            W (int): parameter for time warping
            adaptive_number_ratio (float): adaptive multiplicity ratio for time masking
            adaptive_size_ratio (float): adaptive size ratio for time masking
            max_n_time_masks (int): maximum number of time masks
            replace_with_zero (bool): pad zero on mask if true else use mean
            warp_mode (str): "PIL" (default, fast, not differentiable)
                             or "sparse_image_warp" (slow, differentiable)
        """
        super().__init__()
        self._rng = rng
        self.inplace = True
        self.replace_with_zero = replace_with_zero

        self.mode = warp_mode
        self.W = W
        self.F = F
        self.T = T
        self.n_freq_masks = n_freq_masks
        self.n_time_masks = n_time_masks
        self.p = p

        # adaptive SpecAugment
        self.adaptive_number_ratio = adaptive_number_ratio
        self.adaptive_size_ratio = adaptive_size_ratio
        self.max_n_time_masks = max_n_time_masks

        if adaptive_number_ratio > 0:
            self.n_time_masks = 0
            logger.info('n_time_masks is set to zero for adaptive SpecAugment.')
        if adaptive_size_ratio > 0:
            self.T = 0
            logger.info('T is set to zero for adaptive SpecAugment.')

        self._freq_mask = None
        self._time_mask = None
    def librispeech_basic(self):
        self.W = 80
        self.F = 27
        self.T = 100
        self.n_freq_masks = 1
        self.n_time_masks = 1
        self.p = 1.0

    def librispeech_double(self):
        self.W = 80
        self.F = 27
        self.T = 100
        self.n_freq_masks = 2
        self.n_time_masks = 2
        self.p = 1.0

    def switchboard_mild(self):
        self.W = 40
        self.F = 15
        self.T = 70
        self.n_freq_masks = 2
        self.n_time_masks = 2
        self.p = 0.2

    def switchboard_strong(self):
        self.W = 40
        self.F = 27
        self.T = 70
        self.n_freq_masks = 2
        self.n_time_masks = 2
        self.p = 0.2

    @property
    def freq_mask(self):
        return self._freq_mask

    @property
    def time_mask(self):
        return self._time_mask

    def __repr__(self):
        return "specaug: F-{}, T-{}, F-n-{}, T-n-{}".format(
            self.F, self.T, self.n_freq_masks, self.n_time_masks)
    def time_warp(self, x, mode='PIL'):
        """Time warp for SpecAugment:
        move a random center frame by a random width ~ uniform(-window, window).

        Args:
            x (np.ndarray): spectrogram (time, freq)
            mode (str): PIL or sparse_image_warp

        Raises:
            NotImplementedError: If mode is 'sparse_image_warp', which is not
                implemented yet.
            NotImplementedError: If mode is not a supported warp mode.

        Returns:
            np.ndarray: time warped spectrogram (time, freq)
        """
        window = max_time_warp = self.W
        if window == 0:
            return x

        if mode == "PIL":
            t = x.shape[0]
            if t - window <= window:
                return x
            # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1
            center = random.randrange(window, t - window)
            warped = random.randrange(center - window, center +
                                      window) + 1  # 1 ... t - 1

            left = Image.fromarray(x[:center]).resize((x.shape[1], warped),
                                                      BICUBIC)
            right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped),
                                                       BICUBIC)
            if self.inplace:
                x[:warped] = left
                x[warped:] = right
                return x
            return np.concatenate((left, right), 0)
        elif mode == "sparse_image_warp":
            raise NotImplementedError('sparse_image_warp')
        else:
            raise NotImplementedError(
                "unknown resize mode: " + mode +
                ", choose one from (PIL, sparse_image_warp).")
    def mask_freq(self, x, replace_with_zero=False):
        """Frequency masking.

        Args:
            x (np.ndarray): spectrogram (time, freq)
            replace_with_zero (bool, optional): Defaults to False.

        Returns:
            np.ndarray: frequency-masked spectrogram (time, freq)
        """
        n_bins = x.shape[1]
        for i in range(0, self.n_freq_masks):
            f = int(self._rng.uniform(low=0, high=self.F))
            f_0 = int(self._rng.uniform(low=0, high=n_bins - f))
            assert f_0 <= f_0 + f
            if replace_with_zero:
                x[:, f_0:f_0 + f] = 0
            else:
                x[:, f_0:f_0 + f] = x.mean()
            self._freq_mask = (f_0, f_0 + f)
        return x
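
    # Hypothetical sketch: with F=27 and an 80-bin spectrogram, each mask
    # zeroes (or mean-fills) a band of f ~ U(0, 27) consecutive bins starting
    # at f_0 ~ U(0, 80 - f); time masking below works the same way along the
    # time axis, with the mask width additionally capped at p * n_frames.
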
    def mask_time(self, x, replace_with_zero=False):
        """Time masking.

        Args:
            x (np.ndarray): spectrogram (time, freq)
            replace_with_zero (bool, optional): Defaults to False.

        Returns:
            np.ndarray: time-masked spectrogram (time, freq)
        """
        n_frames = x.shape[0]

        if self.adaptive_number_ratio > 0:
            n_masks = int(n_frames * self.adaptive_number_ratio)
            n_masks = min(n_masks, self.max_n_time_masks)
        else:
            n_masks = self.n_time_masks

        if self.adaptive_size_ratio > 0:
            T = self.adaptive_size_ratio * n_frames
        else:
            T = self.T

        for i in range(n_masks):
            t = int(self._rng.uniform(low=0, high=T))
            t = min(t, int(n_frames * self.p))
            t_0 = int(self._rng.uniform(low=0, high=n_frames - t))
            assert t_0 <= t_0 + t
            if replace_with_zero:
                x[t_0:t_0 + t, :] = 0
            else:
                x[t_0:t_0 + t, :] = x.mean()
            self._time_mask = (t_0, t_0 + t)
        return x
    def __call__(self, x, train=True):
        if not train:
            return x
        return self.transform_feature(x)

    def transform_feature(self, x: np.ndarray):
        """
        Args:
            x (np.ndarray): `[T, F]`
        Returns:
            x (np.ndarray): `[T, F]`
        """
        assert isinstance(x, np.ndarray)
        assert x.ndim == 2
        x = self.time_warp(x, self.mode)
        x = self.mask_freq(x, self.replace_with_zero)
        x = self.mask_time(x, self.replace_with_zero)
        return x
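
# Hypothetical usage sketch (values follow the librispeech_basic policy):
#
#   aug = SpecAugmentor(np.random.RandomState(0), F=27, T=100,
#                       n_freq_masks=1, n_time_masks=1, p=1.0, W=80)
#   spec = aug.transform_feature(spec)  # spec: np.ndarray of shape [T, F]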
@ -1,106 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speed perturbation augmentation model."""
import numpy as np

from deepspeech.frontend.augmentor.base import AugmentorBase


class SpeedPerturbAugmentor(AugmentorBase):
    """Augmentation model for adding speed perturbation."""

    def __init__(self, rng, min_speed_rate=0.9, max_speed_rate=1.1,
                 num_rates=3):
        """Speed perturbation.

        The speed perturbation in kaldi uses sox-speed instead of sox-tempo,
        and sox-speed simply resamples the input,
        i.e. both pitch and tempo are changed.

        "Why use speed option instead of tempo -s in SoX for speed perturbation"
        https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8

        Sox speed:
        https://pysox.readthedocs.io/en/latest/api.html#sox.transform.Transformer

        See reference paper here:
        http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf

        Espnet:
        https://espnet.github.io/espnet/_modules/espnet/transform/perturb.html

        Nemo:
        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/perturb.py#L92

        Args:
            rng (random.Random): Random generator object.
            min_speed_rate (float): Lower bound of the new speed rate to
                sample; should not be smaller than 0.9.
            max_speed_rate (float): Upper bound of the new speed rate to
                sample; should not be larger than 1.1.
            num_rates (int, optional): Number of discrete rates to allow.
                Can be a positive or negative integer. Defaults to 3.
                If a positive integer greater than 0 is provided, the range of
                speed rates will be discretized into `num_rates` values.
                If a negative integer or 0 is provided, the full range of
                speed rates will be sampled uniformly.
                Note: If a positive integer is provided and the resultant
                discretized range of rates contains the value '1.0', then
                those samples with rate=1.0 will not be augmented at all and
                simply skipped. This avoids unnecessary augmentation and
                increased computation time. The effective augmentation chance
                in such a case is `prob * ((num_rates - 1) / num_rates) * 100`%,
                where `prob` is the global probability of a sample being
                augmented.

        Raises:
            ValueError: when the speed rate bounds are invalid
        """
        if min_speed_rate < 0.9:
            raise ValueError(
                "Sampling speed below 0.9 can cause unnatural effects")
        if max_speed_rate > 1.1:
            raise ValueError(
                "Sampling speed above 1.1 can cause unnatural effects")
        self._min_rate = min_speed_rate
        self._max_rate = max_speed_rate
        self._rng = rng
        self._num_rates = num_rates
        if num_rates > 0:
            self._rates = np.linspace(
                self._min_rate, self._max_rate, self._num_rates, endpoint=True)

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Sample a new speed rate from the given range and
        change the speed of the given audio clip.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        # num_rates <= 0 means sampling from the continuous range; the
        # original check only covered negative values, which would leave
        # self._rates undefined for num_rates == 0.
        if self._num_rates <= 0:
            speed_rate = self._rng.uniform(self._min_rate, self._max_rate)
        else:
            speed_rate = self._rng.choice(self._rates)

        # Skip perturbation in case of identity speed rate
        if speed_rate == 1.0:
            return

        audio_segment.change_speed(speed_rate)
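
# Hypothetical usage sketch: num_rates=3 yields the discrete rates
# {0.9, 1.0, 1.1}, matching kaldi-style 3-way speed perturbation.
#
#   aug = SpeedPerturbAugmentor(np.random.RandomState(0),
#                               min_speed_rate=0.9, max_speed_rate=1.1,
#                               num_rates=3)
#   aug.transform_audio(audio_segment)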
@ -1,55 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the volume perturb augmentation model."""
from deepspeech.frontend.augmentor.base import AugmentorBase


class VolumePerturbAugmentor(AugmentorBase):
    """Augmentation model for adding random volume perturbation.

    This is used for multi-loudness training of PCEN. See

    https://arxiv.org/pdf/1607.05666v1.pdf

    for more details.

    :param rng: Random generator object.
    :type rng: random.Random
    :param min_gain_dBFS: Minimal gain in dBFS.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Maximal gain in dBFS.
    :type max_gain_dBFS: float
    """

    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS
        self._rng = rng

    def __call__(self, x, uttid=None, train=True):
        if not train:
            return x
        self.transform_audio(x)
        return x

    def transform_audio(self, audio_segment):
        """Change audio loudness.

        Note that this is an in-place transformation.

        :param audio_segment: Audio segment to add effects to.
        :type audio_segment: AudioSegment|SpeechSegment
        """
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.gain_db(gain)
@ -1,16 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .audio_featurizer import AudioFeaturizer  #noqa: F401
from .speech_featurizer import SpeechFeaturizer
from .text_featurizer import TextFeaturizer
@ -1,363 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the audio featurizer class."""
import numpy as np
from python_speech_features import delta
from python_speech_features import logfbank
from python_speech_features import mfcc


class AudioFeaturizer():
    """Audio featurizer, for extracting features from audio contents of
    AudioSegment or SpeechSegment.

    Currently, it supports feature types of linear spectrogram, mfcc and
    fbank.

    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc',
                          'fbank'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Audio is resampled (if upsampling or
                               downsampling is allowed) to this before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 specgram_type: str='linear',
                 feat_dim: int=None,
                 delta_delta: bool=False,
                 stride_ms=10.0,
                 window_ms=20.0,
                 n_fft=None,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20,
                 dither=1.0):
        self._specgram_type = specgram_type
        # mfcc and fbank use `feat_dim`
        self._feat_dim = feat_dim
        # mfcc and fbank use `delta_delta`
        self._delta_delta = delta_delta
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq
        self._target_sample_rate = target_sample_rate
        self._use_dB_normalization = use_dB_normalization
        self._target_dB = target_dB
        self._fft_point = n_fft
        self._dither = dither

    def featurize(self,
                  audio_segment,
                  allow_downsampling=True,
                  allow_upsampling=True):
        """Extract audio features from AudioSegment or SpeechSegment.

        :param audio_segment: Audio/speech segment to extract features from.
        :type audio_segment: AudioSegment|SpeechSegment
        :param allow_downsampling: Whether to allow audio downsampling before
                                   featurizing.
        :type allow_downsampling: bool
        :param allow_upsampling: Whether to allow audio upsampling before
                                 featurizing.
        :type allow_upsampling: bool
        :return: Spectrogram audio feature in 2darray.
        :rtype: ndarray
        :raises ValueError: If audio sample rate is not supported.
        """
        # upsampling or downsampling
        if ((audio_segment.sample_rate > self._target_sample_rate and
             allow_downsampling) or
                (audio_segment.sample_rate < self._target_sample_rate and
                 allow_upsampling)):
            audio_segment.resample(self._target_sample_rate)
        if audio_segment.sample_rate != self._target_sample_rate:
            raise ValueError("Audio sample rate is not supported. "
                             "Turn allow_downsampling or allow_upsampling on.")
        # decibel normalization
        if self._use_dB_normalization:
            audio_segment.normalize(target_db=self._target_dB)
        # extract spectrogram
        return self._compute_specgram(audio_segment)

    @property
    def stride_ms(self):
        return self._stride_ms

    @property
    def feature_size(self):
        """audio feature size"""
        feat_dim = 0
        if self._specgram_type == 'linear':
            fft_point = self._window_ms if self._fft_point is None else self._fft_point
            feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 +
                           1)
        elif self._specgram_type == 'mfcc':
            # mfcc, delta, delta-delta
            feat_dim = int(self._feat_dim *
                           3) if self._delta_delta else int(self._feat_dim)
        elif self._specgram_type == 'fbank':
            # fbank, delta, delta-delta
            feat_dim = int(self._feat_dim *
                           3) if self._delta_delta else int(self._feat_dim)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc, fbank." %
                             self._specgram_type)
        return feat_dim

    def _compute_specgram(self, audio_segment):
        """Extract various audio features."""
        sample_rate = audio_segment.sample_rate
        if self._specgram_type == 'linear':
            samples = audio_segment.samples
            return self._compute_linear_specgram(
                samples,
                sample_rate,
                stride_ms=self._stride_ms,
                window_ms=self._window_ms,
                max_freq=self._max_freq)
        elif self._specgram_type == 'mfcc':
            samples = audio_segment.to('int16')
            return self._compute_mfcc(
                samples,
                sample_rate,
                feat_dim=self._feat_dim,
                stride_ms=self._stride_ms,
                window_ms=self._window_ms,
                max_freq=self._max_freq,
                dither=self._dither,
                delta_delta=self._delta_delta)
        elif self._specgram_type == 'fbank':
            samples = audio_segment.to('int16')
            return self._compute_fbank(
                samples,
                sample_rate,
                feat_dim=self._feat_dim,
                stride_ms=self._stride_ms,
                window_ms=self._window_ms,
                max_freq=self._max_freq,
                dither=self._dither,
                delta_delta=self._delta_delta)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear, mfcc, fbank." %
                             self._specgram_type)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram for samples from a real signal."""
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        # https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
        fft = np.fft.rfft(windows * weighting, n=None, axis=0)
        fft = np.absolute(fft)
        fft = fft**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Compute the linear spectrogram from FFT energy.

        Args:
            samples (np.ndarray): audio samples.
            sample_rate (float): sample rate of the samples, in Hz.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 20.0.
            max_freq (float, optional): only FFT bins for frequencies in
                [0, max_freq] are returned. Defaults to None (half of the
                sample rate).
            eps (float, optional): small constant added before taking the log.
                Defaults to 1e-14.

        Raises:
            ValueError: If max_freq is greater than half of the sample rate.
            ValueError: If stride_ms is greater than window_ms.

        Returns:
            np.ndarray: log spectrogram, (time, freq)
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        # (freq, time)
        spec = np.log(specgram[:ind, :] + eps)
        return np.transpose(spec)

    def _concat_delta_delta(self, feat):
        """Append delta and delta-delta features.

        Args:
            feat (np.ndarray): (T, D)

        Returns:
            np.ndarray: feat with delta-delta, (T, 3*D)
        """
        # Deltas
        d_feat = delta(feat, 2)
        # Deltas-Deltas: the delta of the delta, not of the raw feature
        dd_feat = delta(d_feat, 2)
        # concat above three features
        concat_feat = np.concatenate((feat, d_feat, dd_feat), axis=1)
        return concat_feat
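A quick numpy illustration of what _concat_delta_delta produces (python_speech_features.delta computes a regression-based derivative over a window of ±N frames):

import numpy as np
from python_speech_features import delta

feat = np.random.randn(100, 13)  # (T, D) base features
d_feat = delta(feat, 2)          # first-order derivative, (T, D)
dd_feat = delta(d_feat, 2)       # second-order derivative, (T, D)
full = np.concatenate((feat, d_feat, dd_feat), axis=1)
assert full.shape == (100, 39)   # (T, 3 * D)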

    def _compute_mfcc(self,
                      samples,
                      sample_rate,
                      feat_dim=13,
                      stride_ms=10.0,
                      window_ms=25.0,
                      max_freq=None,
                      dither=1.0,
                      delta_delta=True):
        """Compute mfcc from samples.

        Args:
            samples (np.ndarray, np.int16): the audio signal from which to compute features.
            sample_rate (float): the sample rate of the signal we are working with, in Hz.
            feat_dim (int): the number of cepstra to return, default 13.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 25.0.
            max_freq (float, optional): highest band edge of mel filters. In Hz, default is samplerate/2. Defaults to None.
            dither (float, optional): dithering constant; 0.0 disables dithering. Defaults to 1.0.
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to True.

        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms

        Returns:
            np.ndarray: mfcc feature, (T, D).
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # compute the 13 cepstral coefficients, and the first one is replaced
        # by log(frame energy), (T, D)
        mfcc_feat = mfcc(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            numcep=feat_dim,
            nfilt=23,
            nfft=512,
            lowfreq=20,
            highfreq=max_freq,
            dither=dither,
            remove_dc_offset=True,
            preemph=0.97,
            ceplifter=22,
            useEnergy=True,
            winfunc='povey')
        if delta_delta:
            mfcc_feat = self._concat_delta_delta(mfcc_feat)
        return mfcc_feat

    def _compute_fbank(self,
                       samples,
                       sample_rate,
                       feat_dim=40,
                       stride_ms=10.0,
                       window_ms=25.0,
                       max_freq=None,
                       dither=1.0,
                       delta_delta=False):
        """Compute logfbank from samples.

        Args:
            samples (np.ndarray, np.int16): the audio signal from which to compute features. Should be an N*1 array.
            sample_rate (float): the sample rate of the signal we are working with, in Hz.
            feat_dim (int): the number of mel filterbanks to return, default 40.
            stride_ms (float, optional): stride length in ms. Defaults to 10.0.
            window_ms (float, optional): window length in ms. Defaults to 25.0.
            max_freq (float, optional): highest band edge of mel filters. In Hz, default is samplerate/2. Defaults to None.
            dither (float, optional): dithering constant; 0.0 disables dithering. Defaults to 1.0.
            delta_delta (bool, optional): whether to append delta and delta-delta features. Defaults to False.

        Raises:
            ValueError: max_freq > samplerate/2
            ValueError: stride_ms > window_ms

        Returns:
            np.ndarray: fbank feature, (T, D).
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        # (T, D)
        fbank_feat = logfbank(
            signal=samples,
            samplerate=sample_rate,
            winlen=0.001 * window_ms,
            winstep=0.001 * stride_ms,
            nfilt=feat_dim,
            nfft=512,
            lowfreq=20,
            highfreq=max_freq,
            dither=dither,
            remove_dc_offset=True,
            preemph=0.97,
            wintype='povey')
        if delta_delta:
            fbank_feat = self._concat_delta_delta(fbank_feat)
        return fbank_feat
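As a cross-check on the strided-window spectrogram above, the same computation can be written with an explicit frame loop (a simplified sketch using the same Hanning window and power scaling; hypothetical random input):

import numpy as np

def linear_specgram(samples, sample_rate, stride_ms=10.0, window_ms=20.0, eps=1e-14):
    stride = int(0.001 * sample_rate * stride_ms)
    window = int(0.001 * sample_rate * window_ms)
    weighting = np.hanning(window)
    scale = np.sum(weighting ** 2) * sample_rate
    frames = []
    for start in range(0, len(samples) - window + 1, stride):
        spec = np.abs(np.fft.rfft(samples[start:start + window] * weighting)) ** 2
        spec[1:-1] *= 2.0 / scale
        spec[[0, -1]] /= scale
        frames.append(np.log(spec + eps))
    return np.array(frames)  # (time, freq)

sig = np.random.randn(16000).astype('float32')   # 1 second of noise at 16 kHz
print(linear_specgram(sig, 16000).shape)         # (99, 161) with the defaults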
@ -1,153 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech featurizer class."""
from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer


class SpeechFeaturizer():
    """Speech featurizer, for extracting features from both audio and transcript
    contents of SpeechSegment.

    Currently, for audio parts, it supports feature types of linear
    spectrogram, mfcc and fbank; for transcript parts, it supports
    char/word/sentence-piece level tokenizing and conversion into a list of
    token indices. Note that the token indexing order follows the given
    vocabulary file.

    :param vocab_filepath: Filepath to load vocabulary for token indices
                           conversion.
    :type vocab_filepath: str
    :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: When specgram_type is 'linear', only FFT bins
                     corresponding to frequencies between [0, max_freq] are
                     returned; when specgram_type is 'mfcc', max_freq is the
                     highest band edge of mel filters.
    :type max_freq: None|float
    :param target_sample_rate: Speech is resampled (if upsampling or
                               downsampling is allowed) to this before
                               extracting spectrogram features.
    :type target_sample_rate: float
    :param use_dB_normalization: Whether to normalize the audio to a certain
                                 decibel level before extracting the features.
    :type use_dB_normalization: bool
    :param target_dB: Target audio decibels for normalization.
    :type target_dB: float
    """

    def __init__(self,
                 unit_type,
                 vocab_filepath,
                 spm_model_prefix=None,
                 specgram_type='linear',
                 feat_dim=None,
                 delta_delta=False,
                 stride_ms=10.0,
                 window_ms=20.0,
                 n_fft=None,
                 max_freq=None,
                 target_sample_rate=16000,
                 use_dB_normalization=True,
                 target_dB=-20,
                 dither=1.0):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)
        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
                                               spm_model_prefix)

    def featurize(self, speech_segment, keep_transcription_text):
        """Extract features for speech segment.

        1. For audio parts, extract the audio features.
        2. For transcript parts, keep the original text or convert the text
           string to a list of token indices.

        Args:
            speech_segment (SpeechSegment): Speech segment to extract features from.
            keep_transcription_text (bool): True to keep the transcript text, False to return token ids.

        Returns:
            tuple: 1) spectrogram audio feature in 2darray, 2) list of token indices.
        """
        spec_feature = self._audio_featurizer.featurize(speech_segment)
        if keep_transcription_text:
            return spec_feature, speech_segment.transcript
        if speech_segment.has_token:
            text_ids = speech_segment.token_ids
        else:
            text_ids = self._text_featurizer.featurize(
                speech_segment.transcript)
        return spec_feature, text_ids

    @property
    def vocab_size(self):
        """Return the vocabulary size.

        Returns:
            int: Vocabulary size.
        """
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary in list.

        Returns:
            List[str]: vocabulary tokens.
        """
        return self._text_featurizer.vocab_list

    @property
    def vocab_dict(self):
        """Return the vocabulary in dict.

        Returns:
            Dict[str, int]: token-to-id mapping.
        """
        return self._text_featurizer.vocab_dict

    @property
    def feature_size(self):
        """Return the audio feature size.

        Returns:
            int: audio feature size.
        """
        return self._audio_featurizer.feature_size

    @property
    def stride_ms(self):
        """Time length in `ms` unit per frame.

        Returns:
            float: time(ms)/frame
        """
        return self._audio_featurizer.stride_ms

    @property
    def text_feature(self):
        """Return the text featurizer object.

        Returns:
            TextFeaturizer: object.
        """
        return self._text_featurizer
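A sketch of driving the featurizer end to end (hypothetical file paths and vocabulary; SpeechSegment.from_file is defined in the speech segment class later in this diff, and the module import paths are assumptions):

from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.speech_segment import SpeechSegment  # assumed path

featurizer = SpeechFeaturizer(
    unit_type='char',
    vocab_filepath='data/vocab.txt',  # hypothetical vocab, one token per line
    specgram_type='fbank',
    feat_dim=40)
seg = SpeechSegment.from_file('data/utt_0001.wav', transcript='今天天气很好')
feat, token_ids = featurizer.featurize(seg, keep_transcription_text=False)
print(feat.shape, len(token_ids))  # (T, 40) and the token count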
@ -1,202 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
import sentencepiece as spm

from ..utility import EOS
from ..utility import load_dict
from ..utility import UNK

__all__ = ["TextFeaturizer"]


class TextFeaturizer():
    def __init__(self,
                 unit_type,
                 vocab_filepath,
                 spm_model_prefix=None,
                 maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and conversion into
        a list of token indices. Note that the token indexing order follows the
        given vocabulary file.

        Args:
            unit_type (str): unit type, e.g. char, word, spm
            vocab_filepath (str): Filepath to load vocabulary for token indices conversion.
            spm_model_prefix (str, optional): spm model prefix. Defaults to None.
            maskctc (bool, optional): whether to reserve the <mask> token for Mask-CTC. Defaults to False.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab_filepath:
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id = self._load_vocabulary_from_file(
                vocab_filepath, maskctc)
            self.vocab_size = len(self.vocab_list)

        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text):
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Args:
            text (str): Text.

        Returns:
            List[int]: List of token indices.
        """
        tokens = self.tokenize(text)
        ids = []
        for token in tokens:
            token = token if token in self.vocab_dict else self.unk
            ids.append(self.vocab_dict[token])
        return ids

    def defeaturize(self, idxs):
        """Convert a list of token indices to a text string,
        ignoring indices after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        text = self.detokenize(tokens)
        return text

    def char_tokenize(self, text):
        """Character tokenizer.

        Args:
            text (str): text string.

        Returns:
            List[str]: tokens.
        """
        return list(text.strip())

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        return "".join(tokens)

    def word_tokenize(self, text):
        """Word tokenizer, separated by <space>."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, separated by <space>."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence pieces str code
        """
        stats = {"num_empty": 0, "num_filtered": 0}

        def valid(line):
            return True

        def encode(line):
            return self.sp.EncodeAsPieces(line)

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        enc_line = encode_line(text)
        return enc_line

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            tokens (List[str]): tokens.
            input_format (str, optional): 'piece' or 'id'. Defaults to 'piece'.

        Returns:
            str: text
        """
        if input_format == "piece":

            def decode(parts):
                return "".join(self.sp.DecodePieces(parts))
        elif input_format == "id":

            def decode(parts):
                return "".join(self.sp.DecodeIds(parts))

        return decode(tokens)

    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
        """Load vocabulary from file."""
        vocab_list = load_dict(vocab_filepath, maskctc)
        assert vocab_list is not None

        id2token = dict(
            [(idx, token) for (idx, token) in enumerate(vocab_list)])
        token2id = dict(
            [(token, idx) for (idx, token) in enumerate(vocab_list)])

        unk_id = vocab_list.index(UNK)
        eos_id = vocab_list.index(EOS)
        return token2id, id2token, vocab_list, unk_id, eos_id
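A round-trip sketch of the char-level mapping that featurize/defeaturize implement (hypothetical five-token vocabulary; <unk> catches out-of-vocabulary characters):

vocab_list = ['<blank>', '今', '天', '<unk>', '<eos>']
token2id = {token: idx for idx, token in enumerate(vocab_list)}
id2token = {idx: token for idx, token in enumerate(vocab_list)}

text = '今天'
ids = [token2id.get(tok, token2id['<unk>']) for tok in list(text)]
assert ids == [1, 2]
assert ''.join(id2token[i] for i in ids) == text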
@ -1,199 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains feature normalizers."""
import json

import numpy as np
import paddle
from paddle.io import DataLoader
from paddle.io import Dataset

from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log

__all__ = ["FeatureNormalizer"]

logger = Log(__name__).getlog()


# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
    def __init__(self, feature_func):
        self.feature_func = feature_func

    def __call__(self, batch):
        mean_stat = None
        var_stat = None
        number = 0
        for item in batch:
            audioseg = AudioSegment.from_file(item['feat'])
            feat = self.feature_func(audioseg)  # (T, D)

            sums = np.sum(feat, axis=0)
            if mean_stat is None:
                mean_stat = sums
            else:
                mean_stat += sums

            square_sums = np.sum(np.square(feat), axis=0)
            if var_stat is None:
                var_stat = square_sums
            else:
                var_stat += square_sums

            number += feat.shape[0]
        return number, mean_stat, var_stat


class AudioDataset(Dataset):
    def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):
        self._rng = rng if rng else np.random.RandomState(random_seed)
        manifest = read_manifest(manifest_path)
        if num_samples == -1:
            sampled_manifest = manifest
        else:
            sampled_manifest = self._rng.choice(
                manifest, num_samples, replace=False)
        self.items = sampled_manifest

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class FeatureNormalizer(object):
    """Feature normalizer. Normalize features to be of zero mean and unit
    stddev.

    If mean_std_filepath is provided (not None), the normalizer will directly
    initialize from the file. Otherwise, both manifest_path and featurize_func
    should be given for on-the-fly mean and stddev computing.

    :param mean_std_filepath: File containing the pre-computed mean and stddev.
    :type mean_std_filepath: None|str
    :param manifest_path: Manifest of instances for computing mean and stddev.
    :type manifest_path: None|str
    :param featurize_func: Function to extract features. It should be callable
                           with ``featurize_func(audio_segment)``.
    :type featurize_func: None|callable
    :param num_samples: Number of random samples for computing mean and stddev.
    :type num_samples: int
    :param random_seed: Random seed for sampling instances.
    :type random_seed: int
    :raises ValueError: If both mean_std_filepath and manifest_path
                        (or both mean_std_filepath and featurize_func) are None.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 num_workers=0,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = np.random.RandomState(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples,
                                   num_workers)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features):
        """Normalize features to be of zero mean and unit stddev.

        :param features: Input features to be normalized.
        :type features: ndarray, shape (T, D)
        :return: Normalized features.
        :rtype: ndarray
        """
        return (features - self._mean) * self._istd

    def _read_mean_std_from_file(self, filepath, eps=1e-20):
        """Load mean and std from file."""
        mean, istd = load_cmvn(filepath, filetype='json')
        self._mean = np.expand_dims(mean, axis=0)
        self._istd = np.expand_dims(istd, axis=0)

    def write_to_file(self, filepath):
        """Write the mean and stddev to the file.

        :param filepath: File to write mean and stddev.
        :type filepath: str
        """
        with open(filepath, 'w') as fout:
            fout.write(json.dumps(self.cmvn_info))

    def _compute_mean_std(self,
                          manifest_path,
                          featurize_func,
                          num_samples,
                          num_workers,
                          batch_size=64,
                          eps=1e-20):
        """Compute mean and std from randomly sampled instances."""
        paddle.set_device('cpu')

        collate_func = CollateFunc(featurize_func)
        dataset = AudioDataset(manifest_path, num_samples, self._rng)
        data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=collate_func)

        with paddle.no_grad():
            all_mean_stat = None
            all_var_stat = None
            all_number = 0
            wav_number = 0
            for i, batch in enumerate(data_loader):
                number, mean_stat, var_stat = batch
                if i == 0:
                    all_mean_stat = mean_stat
                    all_var_stat = var_stat
                else:
                    all_mean_stat += mean_stat
                    all_var_stat += var_stat
                all_number += number
                # approximate: the last batch may hold fewer than batch_size wavs
                wav_number += batch_size

                if wav_number % 1000 == 0:
                    logger.info(
                        f'process {wav_number} wavs,{all_number} frames.')

        self.cmvn_info = {
            'mean_stat': list(all_mean_stat.tolist()),
            'var_stat': list(all_var_stat.tolist()),
            'frame_num': all_number,
        }

        return self.cmvn_info
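The cmvn_info dict above stores raw sums; turning it into the mean and inverse stddev that apply() expects is a small computation (mirroring _load_json_cmvn defined later in this diff; toy numbers):

import numpy as np

cmvn_info = {'mean_stat': [10.0, 20.0], 'var_stat': [60.0, 250.0], 'frame_num': 5}
mean = np.array(cmvn_info['mean_stat']) / cmvn_info['frame_num']          # [2., 4.]
var = np.array(cmvn_info['var_stat']) / cmvn_info['frame_num'] - mean**2  # [8., 34.]
istd = 1.0 / np.sqrt(np.maximum(var, 1e-20))  # variance floored for stability
normalized = (np.array([[1.0, 2.0]]) - mean) * istd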
@ -1,217 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the speech segment class."""
import numpy as np

from deepspeech.frontend.audio import AudioSegment


class SpeechSegment(AudioSegment):
    """Speech Segment with Text

    Args:
        AudioSegment (AudioSegment): Audio Segment
    """

    def __init__(self,
                 samples,
                 sample_rate,
                 transcript,
                 tokens=None,
                 token_ids=None):
        """Speech segment abstraction, a subclass of AudioSegment,
        with an additional transcript.

        Args:
            samples (ndarray.float32): Audio samples [num_samples x num_channels].
            sample_rate (int): Audio sample rate.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): Transcript tokens for the speech.
            token_ids (List[int], optional): Transcript token ids for the speech.
        """
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript
        # must init `tokens` with `token_ids` at the same time
        self._tokens = tokens
        self._token_ids = token_ids

    def __eq__(self, other):
        """Return whether two objects are equal.

        Returns:
            bool: True, when equal to other
        """
        if not AudioSegment.__eq__(self, other):
            return False
        if self._transcript != other._transcript:
            return False
        if self.has_token and other.has_token:
            if self._tokens != other._tokens:
                return False
            if self._token_ids != other._token_ids:
                return False
        return True

    def __ne__(self, other):
        """Return whether two objects are unequal."""
        return not self.__eq__(other)

    @classmethod
    def from_file(cls, filepath, transcript, tokens=None, token_ids=None):
        """Create speech segment from audio file and corresponding transcript.

        Args:
            filepath (str|file): Filepath or file object to audio file.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): text tokens. Defaults to None.
            token_ids (List[int], optional): text token ids. Defaults to None.

        Returns:
            SpeechSegment: Speech segment instance.
        """
        audio = AudioSegment.from_file(filepath)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)

    @classmethod
    def from_bytes(cls, bytes, transcript, tokens=None, token_ids=None):
        """Create speech segment from a byte string and corresponding
        transcript.

        Args:
            bytes (bytes): Byte string containing audio samples.
            transcript (str): Transcript text for the speech.
            tokens (List[str], optional): text tokens. Defaults to None.
            token_ids (List[int], optional): text token ids. Defaults to None.

        Returns:
            SpeechSegment: Speech segment instance.
        """
        audio = AudioSegment.from_bytes(bytes)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)

    @classmethod
    def concatenate(cls, *segments):
        """Concatenate an arbitrary number of speech segments together, both
        audio and transcript will be concatenated.

        :param *segments: Input speech segments to be concatenated.
        :type *segments: tuple of SpeechSegment
        :return: Speech segment instance.
        :rtype: SpeechSegment
        :raises ValueError: If the number of segments is zero, or if the
                            sample_rate of any two segments does not match.
        :raises TypeError: If any segment is not SpeechSegment instance.
        """
        if len(segments) == 0:
            raise ValueError("No speech segments are given to concatenate.")
        sample_rate = segments[0]._sample_rate
        transcripts = ""
        tokens = []
        token_ids = []
        for seg in segments:
            if sample_rate != seg._sample_rate:
                raise ValueError("Can't concatenate segments with "
                                 "different sample rates")
            if type(seg) is not cls:
                raise TypeError("Only speech segments of the same type "
                                "instance can be concatenated.")
            transcripts += seg._transcript
            # `seg`, not `self`: this is a classmethod
            if seg.has_token:
                tokens += seg._tokens
                token_ids += seg._token_ids
        samples = np.concatenate([seg.samples for seg in segments])
        return cls(samples, sample_rate, transcripts, tokens, token_ids)

    @classmethod
    def slice_from_file(cls,
                        filepath,
                        transcript,
                        tokens=None,
                        token_ids=None,
                        start=None,
                        end=None):
        """Loads a small section of a speech file without having to load
        the entire file into memory, which can be incredibly wasteful.

        :param filepath: Filepath or file object to audio file.
        :type filepath: str|file
        :param start: Start time in seconds. If start is negative, it wraps
                      around from the end. If not provided, this function
                      reads from the very beginning.
        :type start: float
        :param end: End time in seconds. If end is negative, it wraps around
                    from the end. If not provided, the default behavior is
                    to read to the end of the file.
        :type end: float
        :param transcript: Transcript text for the speech. If not provided,
                           the default is an empty string.
        :type transcript: str
        :return: SpeechSegment instance of the specified slice of the input
                 speech file.
        :rtype: SpeechSegment
        """
        audio = AudioSegment.slice_from_file(filepath, start, end)
        return cls(audio.samples, audio.sample_rate, transcript, tokens,
                   token_ids)

    @classmethod
    def make_silence(cls, duration, sample_rate):
        """Creates a silent speech segment of the given duration and
        sample rate, transcript will be an empty string.

        Args:
            duration (float): Length of silence in seconds.
            sample_rate (float): Sample rate.

        Returns:
            SpeechSegment: Silence of the given duration.
        """
        audio = AudioSegment.make_silence(duration, sample_rate)
        return cls(audio.samples, audio.sample_rate, "")

    @property
    def has_token(self):
        if self._tokens and self._token_ids:
            return True
        return False

    @property
    def transcript(self):
        """Return the transcript text.

        Returns:
            str: Transcript text for the speech.
        """
        return self._transcript

    @property
    def tokens(self):
        """Return the transcript text tokens.

        Returns:
            List[str]: text tokens.
        """
        return self._tokens

    @property
    def token_ids(self):
        """Return the transcript text token ids.

        Returns:
            List[int]: text token ids.
        """
        return self._token_ids
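A short sketch of the concatenation semantics: transcripts are joined and audio is appended (hypothetical 16 kHz segments; the module import path is an assumption):

import numpy as np
from deepspeech.frontend.speech_segment import SpeechSegment  # assumed path

sr = 16000
a = SpeechSegment(np.zeros(sr, dtype='float32'), sr, transcript='hello ')
b = SpeechSegment(np.zeros(sr, dtype='float32'), sr, transcript='world')
ab = SpeechSegment.concatenate(a, b)
print(ab.transcript)         # 'hello world'
print(len(ab.samples) / sr)  # ~2.0 seconds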
@ -1,289 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import codecs
import json
import math
import sys
from typing import List
from typing import Optional
from typing import Text

import numpy as np

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = [
    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
    "EOS", "UNK", "BLANK", "MASKCTC"
]

IGNORE_ID = -1
# `sos` and `eos` using same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"


def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
    if dict_path is None:
        return None

    with open(dict_path, "r") as f:
        dictionary = f.readlines()
    char_list = [entry.strip().split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for non-autoregressive maskctc model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list


def read_manifest(
        manifest_path,
        max_input_len=float('inf'),
        min_input_len=0.0,
        max_output_len=float('inf'),
        min_output_len=0.0,
        max_output_input_ratio=float('inf'),
        min_output_input_ratio=0.0, ):
    """Load and parse manifest file.

    Args:
        manifest_path (str): Manifest file to load and parse.
        max_input_len (float, optional): maximum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to float('inf').
        min_input_len (float, optional): minimum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to 0.0.
        max_output_len (float, optional): maximum output seq length,
            in modeling units. Defaults to float('inf').
        min_output_len (float, optional): minimum output seq length,
            in modeling units. Defaults to 0.0.
        max_output_input_ratio (float, optional):
            maximum output seq length / input seq length ratio.
            Defaults to float('inf').
        min_output_input_ratio (float, optional):
            minimum output seq length / input seq length ratio.
            Defaults to 0.0.

    Raises:
        IOError: If failed to parse the manifest.

    Returns:
        List[dict]: Manifest parsing results.
    """
    manifest = []
    for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
        try:
            json_data = json.loads(json_line)
        except Exception as e:
            raise IOError("Error reading manifest: %s" % str(e))

        feat_len = json_data["feat_shape"][
            0] if 'feat_shape' in json_data else 1.0
        token_len = json_data["token_shape"][
            0] if 'token_shape' in json_data else 1.0
        conditions = [
            feat_len >= min_input_len,
            feat_len <= max_input_len,
            token_len >= min_output_len,
            token_len <= max_output_len,
            token_len / feat_len >= min_output_input_ratio,
            token_len / feat_len <= max_output_input_ratio,
        ]
        if all(conditions):
            manifest.append(json_data)
    return manifest
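For orientation, a manifest line that would pass these filters might look like the following (a hypothetical entry; the field layout is inferred from the keys read above, with feat_shape[0] in seconds for raw wav):

{"utt": "utt_0001", "feat": "/data/wav/utt_0001.wav", "feat_shape": [5.85], "text": "今天天气很好", "token_shape": [6]}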


def rms_to_db(rms: float):
    """Root Mean Square to dB.

    Args:
        rms (float): root mean square

    Returns:
        float: dB
    """
    return 20.0 * math.log10(max(1e-16, rms))


def rms_to_dbfs(rms: float):
    """Root Mean Square to dBFS.
    https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
    Audio is treated as a mix of sine waves; a full-scale sine wave with
    amplitude 1.0 has an RMS of 0.7071, i.e. -3.0103 dB, hence:

        dB = dBFS + 3.0103
        dBFS = dB - 3.0103
        e.g. 0 dB = -3.0103 dBFS

    Args:
        rms (float): root mean square

    Returns:
        float: dBFS
    """
    return rms_to_db(rms) - 3.0103
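A worked instance of the convention above, matching the docstring's own example:

import math

rms = 1.0
db = 20.0 * math.log10(max(1e-16, rms))  # 0.0 dB
dbfs = db - 3.0103                       # -3.0103 dBFS
print(db, dbfs)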


def max_dbfs(sample_data: np.ndarray):
    """Peak dBFS based on the maximum energy sample.
    Will prevent overdrive if used for normalization.

    Args:
        sample_data (np.ndarray): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))


def mean_dbfs(sample_data):
    """Mean dBFS based on the RMS energy.

    Args:
        sample_data (np.ndarray): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    return rms_to_dbfs(
        math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))


def gain_db_to_ratio(gain_db: float):
    """dB to ratio

    Args:
        gain_db (float): gain in dB

    Returns:
        float: scale in amp
    """
    return math.pow(10.0, gain_db / 20.0)


def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
    """Normalize audio to dBFS.

    Args:
        sample_data (np.ndarray): input wave samples, [-1, 1].
        dbfs (float, optional): target dBFS. Defaults to -3.0103.

    Returns:
        np.ndarray: normalized wave
    """
    return np.maximum(
        np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
                   1.0), -1.0)


def _load_json_cmvn(json_cmvn_file):
    """Load the json format cmvn stats file and calculate cmvn.

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, vars]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load the kaldi format cmvn stats file and calculate cmvn.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, vars]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """Load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, one of [npz, json, kaldi].

    Raises:
        ValueError: file type not supported.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    assert filetype in ['npz', 'json', 'kaldi'], filetype
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    else:
        # 'npz' passes the assert above but has no loader here yet
        raise ValueError(f"cmvn file type not supported: {filetype}")
    return cmvn[0], cmvn[1]
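A quick check of normalize_audio on a quiet signal (assumes the functions above are in scope; a 440 Hz tone with peak 0.1):

import numpy as np

quiet = 0.1 * np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)
loud = normalize_audio(quiet, dbfs=-3.0103)
print(max_dbfs(quiet))  # about -23.01 dBFS
print(max_dbfs(loud))   # about -3.01 dBFS, the requested target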
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,469 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import itertools
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from deepspeech.utils.log import Log
|
|
||||||
|
|
||||||
__all__ = ["make_batchset"]
|
|
||||||
|
|
||||||
logger = Log(__name__).getlog()
|
|
||||||
|
|
||||||
|
|
||||||
def batchfy_by_seq(
        sorted_data,
        batch_size,
        max_length_in,
        max_length_out,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        iaxis=0,
        okey="output",
        oaxis=0, ):
    """Make batch set from json dictionary

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int batch_size: batch size
    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str ikey: key to access input
        (for ASR ikey="input", for TTS, MT ikey="output".)
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param str okey: key to access output
        (for ASR, MT okey="output". for TTS okey="input".)
    :param int oaxis: dimension to access output
        (for ASR, TTS, MT oaxis=0; reserved for future research, -1 means all axes.)
    :return: List[List[Tuple[str, dict]]] list of batches
    """
    if batch_size <= 0:
        raise ValueError(f"Invalid batch_size={batch_size}")

    # check #utts is more than min_batch_size
    if len(sorted_data) < min_batch_size:
        raise ValueError(
            f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})."
        )

    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        _, info = sorted_data[start]
        ilen = int(info[ikey][iaxis]["shape"][0])
        olen = (int(info[okey][oaxis]["shape"][0]) if oaxis >= 0 else
                max(map(lambda x: int(x["shape"][0]), info[okey])))
        factor = max(int(ilen / max_length_in), int(olen / max_length_out))
        # change batch size depending on the input and output length:
        # if ilen = 1000 and max_length_in = 800
        # then b = batchsize / 2
        # and max(min_batch_size, .) avoids batch size = 0
        bs = max(min_batch_size, int(batch_size / (1 + factor)))
        end = min(len(sorted_data), start + bs)
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()

        # check each batch is more than minimum batch size
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)

        if end == len(sorted_data):
            break
        start = end

    # batch: List[List[Tuple[str, dict]]]
    return minibatches


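# Worked example of the adaptive sizing above (a sketch, not part of the
# original module): with batch_size=32, max_length_in=800 and an utterance
# of ilen=1600 frames, factor = max(1600 // 800, 0) = 2, so the effective
# batch shrinks to max(min_batch_size, 32 // (1 + 2)) = 10 sequences.
def _example_adaptive_batch_size(batch_size=32, max_length_in=800, ilen=1600,
                                 max_length_out=150, olen=30, min_batch_size=1):
    factor = max(int(ilen / max_length_in), int(olen / max_length_out))
    return max(min_batch_size, int(batch_size / (1 + factor)))  # -> 10

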
def batchfy_by_bin(
        sorted_data,
        batch_bins,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variably sized batch set, which maximizes

    the number of bins up to `batch_bins`.

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int batch_bins: Maximum frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
    """
    if batch_bins <= 0:
        raise ValueError(f"invalid batch_bins={batch_bins}")
    length = len(sorted_data)
    idim = int(sorted_data[0][1][ikey][0]["shape"][1])
    odim = int(sorted_data[0][1][okey][0]["shape"][1])
    logger.info("# utts: " + str(len(sorted_data)))
    minibatches = []
    start = 0
    n = 0
    while True:
        # Dynamic batch size depending on size of samples
        b = 0
        next_size = 0
        max_olen = 0
        while next_size < batch_bins and (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim
            if olen > max_olen:
                max_olen = olen
            next_size = (max_olen + ilen) * (b + 1)
            if next_size <= batch_bins:
                b += 1
            elif next_size == 0:
                raise ValueError(
                    f"Can't fit one sample in batch_bins ({batch_bins}): "
                    f"Please increase the value")
        end = min(length, start + max(min_batch_size, b))
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fix the batches if needed
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        if end == length:
            break
        start = end
        n += 1
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
            int(np.mean(lengths))) + " samples).")
    return minibatches


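# Worked example of the bin counting above (a sketch): a "bin" is one
# frame-by-dimension cell, so with idim=83, odim=1, an utterance of 500
# input frames and 20 output tokens costs roughly 500*83 + 20*1 bins;
# batch_bins=800000 would then admit about 19 such utterances per batch.
def _example_bin_budget(batch_bins=800000, ilen=500, idim=83, olen=20, odim=1):
    per_utt = ilen * idim + olen * odim
    return batch_bins // per_utt  # -> 19

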
def batchfy_by_frame(
        sorted_data,
        max_frames_in,
        max_frames_out,
        max_frames_inout,
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        ikey="input",
        okey="output", ):
    """Make variable batch set, which maximizes the number of frames up to the batch-frame limits.

    :param List[(str, Dict[str, Any])] sorted_data: dictionary loaded from data.json
    :param int max_frames_in: Maximum input frames of a batch
    :param int max_frames_out: Maximum output frames of a batch
    :param int max_frames_inout: Maximum input+output frames of a batch
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse

    :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".)
    :param str okey: key to access output (for ASR okey="output". for TTS okey="input".)

    :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches
    """
    if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0:
        raise ValueError(
            "At least one of `--batch-frames-in`, `--batch-frames-out` or "
            "`--batch-frames-inout` should be > 0")
    length = len(sorted_data)
    minibatches = []
    start = 0
    end = 0
    while end != length:
        # Dynamic batch size depending on size of samples
        b = 0
        max_olen = 0
        max_ilen = 0
        while (start + b) < length:
            ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0])
            if ilen > max_frames_in and max_frames_in != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-in ({max_frames_in}): "
                    f"Please increase the value")
            olen = int(sorted_data[start + b][1][okey][0]["shape"][0])
            if olen > max_frames_out and max_frames_out != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-out ({max_frames_out}): "
                    f"Please increase the value")
            if ilen + olen > max_frames_inout and max_frames_inout != 0:
                raise ValueError(
                    f"Can't fit one sample in --batch-frames-inout ({max_frames_inout}): "
                    f"Please increase the value")
            max_olen = max(max_olen, olen)
            max_ilen = max(max_ilen, ilen)
            in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0
            out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0
            inout_ok = (max_ilen + max_olen) * (
                b + 1) <= max_frames_inout or max_frames_inout == 0
            if in_ok and out_ok and inout_ok:
                # add more seq in the minibatch
                b += 1
            else:
                # no more seq in the minibatch
                break
        end = min(length, start + b)
        batch = sorted_data[start:end]
        if shortest_first:
            batch.reverse()
        minibatches.append(batch)
        # Check for min_batch_size and fix the batches if needed
        i = -1
        while len(minibatches[i]) < min_batch_size:
            missing = min_batch_size - len(minibatches[i])
            if -i == len(minibatches):
                minibatches[i + 1].extend(minibatches[i])
                minibatches = minibatches[1:]
                break
            else:
                minibatches[i].extend(minibatches[i - 1][:missing])
                minibatches[i - 1] = minibatches[i - 1][missing:]
                i -= 1
        start = end
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    lengths = [len(x) for x in minibatches]
    logger.info(
        str(len(minibatches)) + " batches containing from " + str(min(lengths))
        + " to " + str(max(lengths)) + " samples " + "(avg " + str(
            int(np.mean(lengths))) + " samples).")

    return minibatches


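# A sketch of how the three frame caps interact (hypothetical numbers):
# with max_frames_in=3000 and utterances of at most 400 input frames,
# in_ok holds while 400 * (b + 1) <= 3000, i.e. up to 7 sequences;
# whichever of the in/out/inout caps is hit first closes the batch.
def _example_frame_cap(max_frames_in=3000, max_ilen=400):
    return max_frames_in // max_ilen  # -> 7 sequences per batch

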
def batchfy_shuffle(data, batch_size, min_batch_size, num_batches,
                    shortest_first):
    import random

    logger.info("use shuffled batch.")
    # list() needed: random.sample requires a sequence, not a dict view
    sorted_data = random.sample(list(data.items()), len(data))
    logger.info("# utts: " + str(len(sorted_data)))
    # make list of minibatches
    minibatches = []
    start = 0
    while True:
        end = min(len(sorted_data), start + batch_size)
        minibatch = sorted_data[start:end]
        if shortest_first:
            minibatch.reverse()
        # check each batch is more than minimum batch size
        if len(minibatch) < min_batch_size:
            mod = min_batch_size - len(minibatch) % min_batch_size
            additional_minibatch = [
                sorted_data[i] for i in np.random.randint(0, start, mod)
            ]
            if shortest_first:
                additional_minibatch.reverse()
            minibatch.extend(additional_minibatch)
        minibatches.append(minibatch)
        if end == len(sorted_data):
            break
        start = end

    # for debugging
    if num_batches > 0:
        minibatches = minibatches[:num_batches]
    logger.info("# minibatches: " + str(len(minibatches)))
    return minibatches


BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]


def make_batchset(
        data,
        batch_size=0,
        max_length_in=float("inf"),
        max_length_out=float("inf"),
        num_batches=0,
        min_batch_size=1,
        shortest_first=False,
        batch_sort_key="input",
        count="auto",
        batch_bins=0,
        batch_frames_in=0,
        batch_frames_out=0,
        batch_frames_inout=0,
        iaxis=0,
        oaxis=0, ):
    """Make batch set from json dictionary

    if utts have "category" value,

    >>> data = [{'category': 'A', 'input': ..., 'utt':'utt1'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt2'},
    ...         {'category': 'B', 'input': ..., 'utt':'utt3'},
    ...         {'category': 'A', 'input': ..., 'utt':'utt4'}]
    >>> make_batchset(data, batch_size=2, ...)
    [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3', ...)]]

    Note that if any utt doesn't have a "category" value,
    this performs the same as batchfy_by_{count}.

    :param List[Dict[str, Any]] data: dictionary loaded from data.json
    :param int batch_size: maximum number of sequences in a minibatch.
    :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
    :param int batch_frames_in: maximum number of input frames in a minibatch.
    :param int batch_frames_out: maximum number of output frames in a minibatch.
    :param int batch_frames_inout: maximum number of input+output frames in a minibatch.
    :param str count: strategy to count maximum size of batch.
        For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES

    :param int max_length_in: maximum length of input to decide adaptive batch size
    :param int max_length_out: maximum length of output to decide adaptive batch size
    :param int num_batches: # number of batches to use (for debug)
    :param int min_batch_size: minimum batch size (for multi-gpu)
    :param bool shortest_first: Sort from batch with shortest samples
        to longest if true, otherwise reverse
    :param str batch_sort_key: how to sort data before creating minibatches
        ["input", "output", "shuffle"]
    :param int iaxis: dimension to access input
        (for ASR, TTS iaxis=0, for MT iaxis=1.)
    :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0,
        reserved for future research, -1 means all axes.)
    :return: List[List[Tuple[str, dict]]] list of batches
    """
    # check args
    if count not in BATCH_COUNT_CHOICES:
        raise ValueError(
            f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}")
    if batch_sort_key not in BATCH_SORT_KEY_CHOICES:
        raise ValueError(f"arg 'batch_sort_key' ({batch_sort_key}) should be "
                         f"one of {BATCH_SORT_KEY_CHOICES}")

    ikey = "input"
    okey = "output"
    batch_sort_axis = 0  # index of list
    if count == "auto":
        if batch_size != 0:
            count = "seq"
        elif batch_bins != 0:
            count = "bin"
        elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0:
            count = "frame"
        else:
            raise ValueError(
                f"cannot detect `count`; manually set one of {BATCH_COUNT_CHOICES}"
            )
        logger.info(f"count is auto detected as {count}")

    if count != "seq" and batch_sort_key == "shuffle":
        raise ValueError(
            "batch_sort_key=shuffle is only available if batch_count=seq")

    category2data = {}  # Dict[str, dict]
    for v in data:
        k = v['utt']
        category2data.setdefault(v.get("category"), {})[k] = v

    batches_list = []  # List[List[List[Tuple[str, dict]]]]
    for d in category2data.values():
        if batch_sort_key == "shuffle":
            batches = batchfy_shuffle(d, batch_size, min_batch_size,
                                      num_batches, shortest_first)
            batches_list.append(batches)
            continue

        # sort it by input lengths (long to short)
        sorted_data = sorted(
            d.items(),
            key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
            reverse=not shortest_first, )
        logger.info("# utts: " + str(len(sorted_data)))

        if count == "seq":
            batches = batchfy_by_seq(
                sorted_data,
                batch_size=batch_size,
                max_length_in=max_length_in,
                max_length_out=max_length_out,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                iaxis=iaxis,
                okey=okey,
                oaxis=oaxis, )
        if count == "bin":
            batches = batchfy_by_bin(
                sorted_data,
                batch_bins=batch_bins,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey, )
        if count == "frame":
            batches = batchfy_by_frame(
                sorted_data,
                max_frames_in=batch_frames_in,
                max_frames_out=batch_frames_out,
                max_frames_inout=batch_frames_inout,
                min_batch_size=min_batch_size,
                shortest_first=shortest_first,
                ikey=ikey,
                okey=okey, )
        batches_list.append(batches)

    if len(batches_list) == 1:
        batches = batches_list[0]
    else:
        # Concat list. This way is faster than "sum(batch_list, [])"
        batches = list(itertools.chain(*batches_list))

    # for debugging
    if num_batches > 0:
        batches = batches[:num_batches]
        logger.info("# minibatches: " + str(len(batches)))

    # batch: List[List[Tuple[str, dict]]]
    return batches
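

# A minimal end-to-end sketch (not from the original module): the `data`
# records mirror espnet-style data.json entries; the "shape" fields
# ([length, dim]) are what the batchfy_* functions read.
def _example_make_batchset():
    data = [{'utt': f'utt{i}',
             'input': [{"shape": [100 + 10 * i, 83]}],
             'output': [{"shape": [10 + i, 5002]}]} for i in range(8)]
    # 4 sequences per batch; count="auto" detects "seq" from batch_size
    return make_batchset(data, batch_size=4)
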
@ -1,321 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile  # used by SpeechCollator._parse_tar; missing in the original
from collections import namedtuple
from typing import Optional

import numpy as np
from yacs.config import CfgNode

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log

__all__ = ["SpeechCollator"]

logger = Log(__name__).getlog()

# namedtuple needs to be defined at module level for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


class SpeechCollator():
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'specgram_type' in config.collator
        assert 'n_fft' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            mean_std_filepath=config.collator.mean_std_filepath,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            specgram_type=config.collator.specgram_type,
            feat_dim=config.collator.feat_dim,
            delta_delta=config.collator.delta_delta,
            stride_ms=config.collator.stride_ms,
            window_ms=config.collator.window_ms,
            n_fft=config.collator.n_fft,
            max_freq=config.collator.max_freq,
            target_sample_rate=config.collator.target_sample_rate,
            use_dB_normalization=config.collator.use_dB_normalization,
            target_dB=config.collator.target_dB,
            dither=config.collator.dither,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, with suffix *.npy
            spm_model_prefix (str): spm model prefix, required if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): fft points for rfft. Defaults to None.
            max_freq (int, optional): max cut freq. Defaults to None.
            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to 0.
            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (typically outside
                training), skip tokenization and keep the raw transcript string;
                if False, text is converted to token ids. Defaults to True.

        Do augmentations
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map of its TarInfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from tar file
        and cache tar file info for the next reading request.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        # NOTE: TarLocalData is an immutable namedtuple whose fields are
        # always set in __init__, so these guards are vestigial; on newer
        # Pythons namedtuple instances also lack __dict__, so this branch
        # may itself raise.
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

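    # The expected audio_file URI for tar members (a sketch; the paths are
    # hypothetical): "tar:" + archive path + "#" + member name, matching the
    # split(':', 1)[1].split('#', 1) parsing above.
    #
    #     uri = 'tar:/data/aishell/wavs.tar#BAC009S0002W0122.wav'
    #     fileobj = collator._subfile_from_tar(uri)
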
    def process_utterance(self, audio_file, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of transcription part,
                 where transcription part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), transcript)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, transcript)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, transcript_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        return specgram, transcript_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch (List): each item is a (utt, audio, text) tuple,
                audio (np.ndarray): shape (T, D)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(utts, audio, audio_lens, text, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            audio, text = self.process_utterance(audio, text)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, converted to unicode ords
            tokens = []
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
                tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        # [B, T, D]
        xs_pad = pad_list(audios, 0.0).astype(np.float32)
        ilens = np.array(audio_lens).astype(np.int64)
        ys_pad = pad_list(texts, IGNORE_ID).astype(np.int64)
        olens = np.array(text_lens).astype(np.int64)
        return utts, xs_pad, ilens, ys_pad, olens

    @property
    def manifest(self):
        # NOTE: self._manifest is never assigned in __init__, so accessing
        # this property raises AttributeError in the original code.
        return self._manifest

    @property
    def vocab_size(self):
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._speech_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._speech_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._speech_featurizer.text_feature

    @property
    def feature_size(self):
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        return self._speech_featurizer.stride_ms
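

# Typical usage of the collator (a sketch with hypothetical names; the
# DataLoader call assumes the paddle.io-style API this repo trains with):
#
#     from paddle.io import DataLoader
#     collator = SpeechCollator.from_config(config)  # yacs CfgNode
#     loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collator)
#     for utts, xs_pad, ilens, ys_pad, olens in loader:
#         ...  # xs_pad: (B, Tmax, D), ys_pad: (B, Umax)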
@ -1,631 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import tarfile  # used by SpeechCollator._parse_tar; missing in the original
from collections import namedtuple
from typing import Optional

import kaldiio
import numpy as np
from yacs.config import CfgNode

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.frontend.normalizer import FeatureNormalizer
from deepspeech.frontend.speech import SpeechSegment
from deepspeech.frontend.utility import IGNORE_ID
from deepspeech.io.utility import pad_sequence
from deepspeech.utils.log import Log

__all__ = ["SpeechCollator", "KaldiPrePorocessedCollator"]

logger = Log(__name__).getlog()

# namedtuple needs to be defined at module level for pickle.
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


class SpeechCollator():
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                mean_std_filepath="",
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
                feat_dim=0,  # 'mfcc', 'fbank'
                delta_delta=False,  # 'mfcc', 'fbank'
                stride_ms=10.0,  # ms
                window_ms=20.0,  # ms
                n_fft=None,  # fft points
                max_freq=None,  # None for samplerate/2
                target_sample_rate=16000,  # target sample rate
                use_dB_normalization=True,
                target_dB=-20,
                dither=1.0,  # feature dither
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'mean_std_filepath' in config.collator
        assert 'vocab_filepath' in config.collator
        assert 'specgram_type' in config.collator
        assert 'n_fft' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            mean_std_filepath=config.collator.mean_std_filepath,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            specgram_type=config.collator.specgram_type,
            feat_dim=config.collator.feat_dim,
            delta_delta=config.collator.delta_delta,
            stride_ms=config.collator.stride_ms,
            window_ms=config.collator.window_ms,
            n_fft=config.collator.n_fft,
            max_freq=config.collator.max_freq,
            target_sample_rate=config.collator.target_sample_rate,
            use_dB_normalization=config.collator.use_dB_normalization,
            target_dB=config.collator.target_dB,
            dither=config.collator.dither,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(
            self,
            aug_file,
            mean_std_filepath,
            vocab_filepath,
            spm_model_prefix,
            random_seed=0,
            unit_type="char",
            specgram_type='linear',  # 'linear', 'mfcc', 'fbank'
            feat_dim=0,  # 'mfcc', 'fbank'
            delta_delta=False,  # 'mfcc', 'fbank'
            stride_ms=10.0,  # ms
            window_ms=20.0,  # ms
            n_fft=None,  # fft points
            max_freq=None,  # None for samplerate/2
            target_sample_rate=16000,  # target sample rate
            use_dB_normalization=True,
            target_dB=-20,
            dither=1.0,
            keep_transcription_text=True):
        """SpeechCollator Collator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, with suffix *.npy
            spm_model_prefix (str): spm model prefix, required if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            stride_ms (float, optional): stride size in ms. Defaults to 10.0.
            window_ms (float, optional): window size in ms. Defaults to 20.0.
            n_fft (int, optional): fft points for rfft. Defaults to None.
            max_freq (int, optional): max cut freq. Defaults to None.
            target_sample_rate (int, optional): target sample rate used for training. Defaults to 16000.
            specgram_type (str, optional): 'linear', 'mfcc' or 'fbank'. Defaults to 'linear'.
            feat_dim (int, optional): audio feature dim, used by 'mfcc' or 'fbank'. Defaults to 0.
            delta_delta (bool, optional): audio feature with delta-delta, used by 'fbank' or 'mfcc'. Defaults to False.
            use_dB_normalization (bool, optional): do dB normalization. Defaults to True.
            target_dB (int, optional): target dB. Defaults to -20.
            random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (typically outside
                training), skip tokenization and keep the raw transcript string;
                if False, text is converted to token ids. Defaults to True.

        Do augmentations
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._normalizer = FeatureNormalizer(
            mean_std_filepath) if mean_std_filepath else None

        self._stride_ms = stride_ms
        self._target_sample_rate = target_sample_rate

        self._speech_featurizer = SpeechFeaturizer(
            unit_type=unit_type,
            vocab_filepath=vocab_filepath,
            spm_model_prefix=spm_model_prefix,
            specgram_type=specgram_type,
            feat_dim=feat_dim,
            delta_delta=delta_delta,
            stride_ms=stride_ms,
            window_ms=window_ms,
            n_fft=n_fft,
            max_freq=max_freq,
            target_sample_rate=target_sample_rate,
            use_dB_normalization=use_dB_normalization,
            target_dB=target_dB,
            dither=dither)

    def _parse_tar(self, file):
        """Parse a tar file to get a tarfile object
        and a map of its TarInfo objects.
        """
        result = {}
        f = tarfile.open(file)
        for tarinfo in f.getmembers():
            result[tarinfo.name] = tarinfo
        return f, result

    def _subfile_from_tar(self, file):
        """Get subfile object from tar.

        It will return a subfile object from tar file
        and cache tar file info for the next reading request.
        """
        tarpath, filename = file.split(':', 1)[1].split('#', 1)
        # NOTE: TarLocalData is an immutable namedtuple whose fields are
        # always set in __init__, so these guards are vestigial; on newer
        # Pythons namedtuple instances also lack __dict__, so this branch
        # may itself raise.
        if 'tar2info' not in self._local_data.__dict__:
            self._local_data.tar2info = {}
        if 'tar2object' not in self._local_data.__dict__:
            self._local_data.tar2object = {}
        if tarpath not in self._local_data.tar2info:
            tar_object, infos = self._parse_tar(tarpath)
            self._local_data.tar2info[tarpath] = infos
            self._local_data.tar2object[tarpath] = tar_object
        return self._local_data.tar2object[tarpath].extractfile(
            self._local_data.tar2info[tarpath][filename])

    @property
    def manifest(self):
        # NOTE: self._manifest is never assigned in __init__, so accessing
        # this property raises AttributeError in the original code.
        return self._manifest

    @property
    def vocab_size(self):
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        return self._speech_featurizer.vocab_list

    @property
    def vocab_dict(self):
        return self._speech_featurizer.vocab_dict

    @property
    def text_feature(self):
        return self._speech_featurizer.text_feature

    @property
    def feature_size(self):
        return self._speech_featurizer.feature_size

    @property
    def stride_ms(self):
        return self._speech_featurizer.stride_ms

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        return specgram, translation_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch (List): each item is a (utt, audio, text) tuple,
                audio (np.ndarray): shape (T, D)
                text (List[int] or str): shape (U,)

        Returns:
            tuple(utts, audio, audio_lens, text, text_lens): batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                text : (B, Umax)
                text_lens: (B)
        """
        audios = []
        audio_lens = []
        texts = []
        text_lens = []
        utts = []
        for utt, audio, text in batch:
            audio, text = self.process_utterance(audio, text)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, converted to unicode ords
            tokens = []
            if self._keep_transcription_text:
                assert isinstance(text, str), (type(text), text)
                tokens = [ord(t) for t in text]
            else:
                tokens = text  # token ids
            tokens = tokens if isinstance(tokens, np.ndarray) else np.array(
                tokens, dtype=np.int64)
            texts.append(tokens)
            text_lens.append(tokens.shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_texts = pad_sequence(
            texts, padding_value=IGNORE_ID).astype(np.int64)
        text_lens = np.array(text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, padded_texts, text_lens


class TripletSpeechCollator(SpeechCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of audio file.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, which could be token ids or text.
        :rtype: tuple of (2darray, list, list)
        """
        if isinstance(audio_file, str) and audio_file.startswith('tar:'):
            speech_segment = SpeechSegment.from_file(
                self._subfile_from_tar(audio_file), translation)
        else:
            speech_segment = SpeechSegment.from_file(audio_file, translation)

        # audio augment
        self._augmentation_pipeline.transform_audio(speech_segment)

        specgram, translation_part = self._speech_featurizer.featurize(
            speech_segment, self._keep_transcription_text)
        transcript_part = self._speech_featurizer._text_featurizer.featurize(
            transcript)
        if self._normalizer:
            specgram = self._normalizer.apply(specgram)

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)
        return specgram, translation_part, transcript_part

    def __call__(self, batch):
        """batch examples

        Args:
            batch (List): each item is a (utt, audio, translation, transcription) tuple,
                audio (np.ndarray): shape (T, D)
                translation (List[int] or str): shape (U,)
                transcription (List[int] or str): shape (V,)

        Returns:
            tuple: batched data.
                audio : (B, Tmax, D)
                audio_lens: (B)
                translation_text : (B, Umax)
                translation_text_lens: (B)
                transcription_text : (B, Vmax)
                transcription_text_lens: (B)
        """
        audios = []
        audio_lens = []
        translation_text = []
        translation_text_lens = []
        transcription_text = []
        transcription_text_lens = []

        utts = []
        for utt, audio, translation, transcription in batch:
            audio, translation, transcription = self.process_utterance(
                audio, translation, transcription)
            # utt
            utts.append(utt)
            # audio
            audios.append(audio)  # [T, D]
            audio_lens.append(audio.shape[0])
            # text
            # for training, text is token ids;
            # else text is a string, converted to unicode ords
            tokens = [[], []]
            for idx, text in enumerate([translation, transcription]):
                if self._keep_transcription_text:
                    assert isinstance(text, str), (type(text), text)
                    tokens[idx] = [ord(t) for t in text]
                else:
                    tokens[idx] = text  # token ids
                tokens[idx] = tokens[idx] if isinstance(
                    tokens[idx], np.ndarray) else np.array(
                        tokens[idx], dtype=np.int64)
            translation_text.append(tokens[0])
            translation_text_lens.append(tokens[0].shape[0])
            transcription_text.append(tokens[1])
            transcription_text_lens.append(tokens[1].shape[0])

        padded_audios = pad_sequence(
            audios, padding_value=0.0).astype(np.float32)  # [B, T, D]
        audio_lens = np.array(audio_lens).astype(np.int64)
        padded_translation = pad_sequence(
            translation_text, padding_value=IGNORE_ID).astype(np.int64)
        translation_lens = np.array(translation_text_lens).astype(np.int64)
        padded_transcription = pad_sequence(
            transcription_text, padding_value=IGNORE_ID).astype(np.int64)
        transcription_lens = np.array(transcription_text_lens).astype(np.int64)
        return utts, padded_audios, audio_lens, (
            padded_translation, padded_transcription), (translation_lens,
                                                        transcription_lens)


class KaldiPrePorocessedCollator(SpeechCollator):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                augmentation_config="",
                random_seed=0,
                unit_type="char",
                vocab_filepath="",
                spm_model_prefix="",
                feat_dim=0,
                stride_ms=10.0,
                keep_transcription_text=False))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a SpeechCollator object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            SpeechCollator: collator object.
        """
        assert 'augmentation_config' in config.collator
        assert 'keep_transcription_text' in config.collator
        assert 'vocab_filepath' in config.collator
        assert config.collator

        if isinstance(config.collator.augmentation_config, (str, bytes)):
            if config.collator.augmentation_config:
                aug_file = io.open(
                    config.collator.augmentation_config,
                    mode='r',
                    encoding='utf8')
            else:
                aug_file = io.StringIO(initial_value='{}', newline='')
        else:
            aug_file = config.collator.augmentation_config
            assert isinstance(aug_file, io.StringIO)

        speech_collator = cls(
            aug_file=aug_file,
            random_seed=0,
            unit_type=config.collator.unit_type,
            vocab_filepath=config.collator.vocab_filepath,
            spm_model_prefix=config.collator.spm_model_prefix,
            feat_dim=config.collator.feat_dim,
            stride_ms=config.collator.stride_ms,
            keep_transcription_text=config.collator.keep_transcription_text)
        return speech_collator

    def __init__(self,
                 aug_file,
                 vocab_filepath,
                 spm_model_prefix,
                 random_seed=0,
                 unit_type="char",
                 feat_dim=0,
                 stride_ms=10.0,
                 keep_transcription_text=True):
        """KaldiPrePorocessedCollator

        Args:
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            spm_model_prefix (str): spm model prefix, required if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            random_seed (int, optional): for random generator. Defaults to 0.
            keep_transcription_text (bool, optional): if True (typically outside
                training), skip tokenization and keep the raw transcript string;
                if False, text is converted to token ids. Defaults to True.

        Do augmentations
        Padding audio features with zeros to make them have the same shape (or
        a user-defined shape) within one batch.
        """
        self._keep_transcription_text = keep_transcription_text
        self._feat_dim = feat_dim
        self._stride_ms = stride_ms

        self._local_data = TarLocalData(tar2info={}, tar2object={})
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=aug_file.read(), random_seed=random_seed)

        self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
                                               spm_model_prefix)

    def process_utterance(self, audio_file, translation):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of a kaldi-processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :return: Tuple of audio feature tensor and data of translation part,
                 where translation part could be token ids or text.
        :rtype: tuple of (2darray, list)
        """
        specgram = kaldiio.load_mat(audio_file)
        assert specgram.shape[
            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[1])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)

        if self._keep_transcription_text:
            return specgram, translation
        else:
            text_ids = self._text_featurizer.featurize(translation)
            return specgram, text_ids


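# Hedged sketch of the expected input: `audio_file` here is a kaldi ark
# specifier (hypothetical path) as read by kaldiio.load_mat, e.g. an entry
# taken from feats.scp such as:
#
#     specgram = kaldiio.load_mat('data/feats.ark:42')  # -> (T, feat_dim)
#
# so this collator consumes precomputed features rather than raw wav files.

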
class TripletKaldiPrePorocessedCollator(KaldiPrePorocessedCollator):
    def process_utterance(self, audio_file, translation, transcript):
        """Load, augment, featurize and normalize for speech data.

        :param audio_file: Filepath or file object of a kaldi-processed feature.
        :type audio_file: str | file
        :param translation: Translation text.
        :type translation: str
        :param transcript: Transcription text.
        :type transcript: str
        :return: Tuple of audio feature tensor and data of translation and
                 transcription parts, which could be token ids or text.
        :rtype: tuple of (2darray, (list, list))
        """
        specgram = kaldiio.load_mat(audio_file)
        assert specgram.shape[
            1] == self._feat_dim, 'expect feat dim {}, but got {}'.format(
                self._feat_dim, specgram.shape[1])

        # specgram augment
        specgram = self._augmentation_pipeline.transform_feature(specgram)

        if self._keep_transcription_text:
            return specgram, translation, transcript
        else:
            translation_text_ids = self._text_featurizer.featurize(translation)
            transcript_text_ids = self._text_featurizer.featurize(transcript)
            return specgram, translation_text_ids, transcript_text_ids

def __call__(self, batch):
|
|
||||||
"""batch examples
|
|
||||||
|
|
||||||
Args:
|
|
||||||
batch ([List]): batch is (audio, text)
|
|
||||||
audio (np.ndarray) shape (T, D)
|
|
||||||
translation (List[int] or str): shape (U,)
|
|
||||||
transcription (List[int] or str): shape (V,)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple(audio, text, audio_lens, text_lens): batched data.
|
|
||||||
audio : (B, Tmax, D)
|
|
||||||
audio_lens: (B)
|
|
||||||
translation_text : (B, Umax)
|
|
||||||
translation_text_lens: (B)
|
|
||||||
transcription_text : (B, Vmax)
|
|
||||||
transcription_text_lens: (B)
|
|
||||||
"""
|
|
||||||
audios = []
|
|
||||||
audio_lens = []
|
|
||||||
translation_text = []
|
|
||||||
translation_text_lens = []
|
|
||||||
transcription_text = []
|
|
||||||
transcription_text_lens = []
|
|
||||||
|
|
||||||
utts = []
|
|
||||||
for utt, audio, translation, transcription in batch:
|
|
||||||
audio, translation, transcription = self.process_utterance(
|
|
||||||
audio, translation, transcription)
|
|
||||||
#utt
|
|
||||||
utts.append(utt)
|
|
||||||
# audio
|
|
||||||
audios.append(audio) # [T, D]
|
|
||||||
audio_lens.append(audio.shape[0])
|
|
||||||
# text
|
|
||||||
# for training, text is token ids
|
|
||||||
# else text is string, convert to unicode ord
|
|
||||||
tokens = [[], []]
|
|
||||||
for idx, text in enumerate([translation, transcription]):
|
|
||||||
if self._keep_transcription_text:
|
|
||||||
assert isinstance(text, str), (type(text), text)
|
|
||||||
tokens[idx] = [ord(t) for t in text]
|
|
||||||
else:
|
|
||||||
tokens[idx] = text # token ids
|
|
||||||
tokens[idx] = tokens[idx] if isinstance(
|
|
||||||
tokens[idx], np.ndarray) else np.array(
|
|
||||||
tokens[idx], dtype=np.int64)
|
|
||||||
translation_text.append(tokens[0])
|
|
||||||
translation_text_lens.append(tokens[0].shape[0])
|
|
||||||
transcription_text.append(tokens[1])
|
|
||||||
transcription_text_lens.append(tokens[1].shape[0])
|
|
||||||
|
|
||||||
padded_audios = pad_sequence(
|
|
||||||
audios, padding_value=0.0).astype(np.float32) #[B, T, D]
|
|
||||||
audio_lens = np.array(audio_lens).astype(np.int64)
|
|
||||||
padded_translation = pad_sequence(
|
|
||||||
translation_text, padding_value=IGNORE_ID).astype(np.int64)
|
|
||||||
translation_lens = np.array(translation_text_lens).astype(np.int64)
|
|
||||||
padded_transcription = pad_sequence(
|
|
||||||
transcription_text, padding_value=IGNORE_ID).astype(np.int64)
|
|
||||||
transcription_lens = np.array(transcription_text_lens).astype(np.int64)
|
|
||||||
return utts, padded_audios, audio_lens, (
|
|
||||||
padded_translation, padded_transcription), (translation_lens,
|
|
||||||
transcription_lens)
|
|
||||||
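
# --- Illustrative sketch (not part of the original diff): how the collator's
# two text branches behave. With keep_transcription_text=True the collator
# falls back to unicode code points via ord(); otherwise it expects token ids.
# A minimal, self-contained numpy demonstration:
import numpy as np

text = "hello"
char_ids = np.array([ord(t) for t in text], dtype=np.int64)
print(char_ids)                            # [104 101 108 108 111]
print(''.join(chr(i) for i in char_ids))   # round-trips back to "hello"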
@ -1,81 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np

from deepspeech.io.utility import pad_list
from deepspeech.utils.log import Log

__all__ = ["CustomConverter"]

logger = Log(__name__).getlog()


class CustomConverter():
    """Custom batch converter.

    Args:
        subsampling_factor (int): The subsampling factor.
        dtype (np.dtype): Data type to convert.

    """

    def __init__(self, subsampling_factor=1, dtype=np.float32):
        """Construct a CustomConverter object."""
        self.subsampling_factor = subsampling_factor
        self.ignore_id = -1
        self.dtype = dtype

    def __call__(self, batch):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.

        Returns:
            tuple(np.ndarray, np.ndarray, np.ndarray)

        """
        # batch should be located in list
        assert len(batch) == 1
        (xs, ys), utts = batch[0]
        assert xs[0] is not None, "please check Reader and Augmentation impl."

        # perform subsampling
        if self.subsampling_factor > 1:
            xs = [x[::self.subsampling_factor, :] for x in xs]

        # get batch of lengths of input sequences
        ilens = np.array([x.shape[0] for x in xs])

        # perform padding and convert to tensor
        # currently only support real number
        if xs[0].dtype.kind == "c":
            xs_pad_real = pad_list([x.real for x in xs], 0).astype(self.dtype)
            xs_pad_imag = pad_list([x.imag for x in xs], 0).astype(self.dtype)
            # Note(kamo):
            # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
            # Don't create ComplexTensor and give it to E2E here
            # because torch.nn.DataParallel can't handle it.
            xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
        else:
            xs_pad = pad_list(xs, 0).astype(self.dtype)

        # NOTE: this is for multi-output (e.g., speech translation)
        ys_pad = pad_list(
            [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
            self.ignore_id)

        olens = np.array(
            [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys])
        return utts, xs_pad, ilens, ys_pad, olens
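
# --- Illustrative sketch (not part of the original diff): the effect of
# CustomConverter's frame subsampling, x[::factor, :], on a (T, D) feature.
import numpy as np

feats = np.arange(12, dtype=np.float32).reshape(6, 2)  # T=6 frames, D=2
sub = feats[::2, :]  # subsampling_factor=2 keeps every 2nd frame
print(sub.shape)     # (3, 2)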
@ -1,170 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Text

import numpy as np
from paddle.io import DataLoader

from deepspeech.frontend.utility import read_manifest
from deepspeech.io.batchfy import make_batchset
from deepspeech.io.converter import CustomConverter
from deepspeech.io.dataset import TransformDataset
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log

__all__ = ["BatchDataLoader"]

logger = Log(__name__).getlog()


def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
                            mode: Text="asr",
                            iaxis=0,
                            oaxis=0):
    if mode == 'asr':
        feat_dim = data_json[0]['input'][iaxis]['shape'][1]
        vocab_size = data_json[0]['output'][oaxis]['shape'][1]
    else:
        raise ValueError(f"{mode} mode is not supported!")
    return feat_dim, vocab_size


def batch_collate(x):
    """De-tuple.

    Args:
        x (List[Tuple]): [(utts, xs, ilens, ys, olens)]

    Returns:
        Tuple: (utts, xs, ilens, ys, olens)
    """
    return x[0]


class BatchDataLoader():
    def __init__(self,
                 json_file: str,
                 train_mode: bool,
                 sortagrad: bool=False,
                 batch_size: int=0,
                 maxlen_in: float=float('inf'),
                 maxlen_out: float=float('inf'),
                 minibatches: int=0,
                 mini_batch_size: int=1,
                 batch_count: str='auto',
                 batch_bins: int=0,
                 batch_frames_in: int=0,
                 batch_frames_out: int=0,
                 batch_frames_inout: int=0,
                 preprocess_conf=None,
                 n_iter_processes: int=1,
                 subsampling_factor: int=1,
                 num_encs: int=1):
        self.json_file = json_file
        self.train_mode = train_mode
        self.use_sortagrad = sortagrad == -1 or sortagrad > 0
        self.batch_size = batch_size
        self.maxlen_in = maxlen_in
        self.maxlen_out = maxlen_out
        self.batch_count = batch_count
        self.batch_bins = batch_bins
        self.batch_frames_in = batch_frames_in
        self.batch_frames_out = batch_frames_out
        self.batch_frames_inout = batch_frames_inout
        self.subsampling_factor = subsampling_factor
        self.num_encs = num_encs
        self.preprocess_conf = preprocess_conf
        self.n_iter_processes = n_iter_processes

        # read json data
        self.data_json = read_manifest(json_file)
        self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
            self.data_json, mode='asr')

        # make minibatch list (variable length)
        self.minibatches = make_batchset(
            self.data_json,
            batch_size,
            maxlen_in,
            maxlen_out,
            minibatches,  # for debug
            min_batch_size=mini_batch_size,
            shortest_first=self.use_sortagrad,
            count=batch_count,
            batch_bins=batch_bins,
            batch_frames_in=batch_frames_in,
            batch_frames_out=batch_frames_out,
            batch_frames_inout=batch_frames_inout,
            iaxis=0,
            oaxis=0, )

        # data reader
        self.reader = LoadInputsAndTargets(
            mode="asr",
            load_output=True,
            preprocess_conf=preprocess_conf,
            preprocess_args={"train": train_mode},  # switch preprocessing mode
        )

        # Setup a converter
        if num_encs == 1:
            self.converter = CustomConverter(
                subsampling_factor=subsampling_factor, dtype=np.float32)
        else:
            raise NotImplementedError("CustomConverterMulEnc is not implemented.")

        # hack to make the batch_size argument 1: the actual batch size is
        # encoded in each list element, and since the default collate function
        # converts numpy arrays to paddle tensors, we use an identity collate
        # function that simply returns the list
        self.dataset = TransformDataset(self.minibatches, self.converter,
                                        self.reader)

        self.dataloader = DataLoader(
            dataset=self.dataset,
            batch_size=1,
            shuffle=not self.use_sortagrad if self.train_mode else False,
            collate_fn=batch_collate,
            num_workers=self.n_iter_processes, )

    def __repr__(self):
        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
        echo += f"train_mode: {self.train_mode}, "
        echo += f"sortagrad: {self.use_sortagrad}, "
        echo += f"batch_size: {self.batch_size}, "
        echo += f"maxlen_in: {self.maxlen_in}, "
        echo += f"maxlen_out: {self.maxlen_out}, "
        echo += f"batch_count: {self.batch_count}, "
        echo += f"batch_bins: {self.batch_bins}, "
        echo += f"batch_frames_in: {self.batch_frames_in}, "
        echo += f"batch_frames_out: {self.batch_frames_out}, "
        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
        echo += f"subsampling_factor: {self.subsampling_factor}, "
        echo += f"num_encs: {self.num_encs}, "
        echo += f"num_workers: {self.n_iter_processes}, "
        echo += f"file: {self.json_file}"
        return echo

    def __len__(self):
        return len(self.dataloader)

    def __iter__(self):
        return self.dataloader.__iter__()

    def __call__(self):
        return self.__iter__()
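
# --- Illustrative sketch (not part of the original diff): constructing the
# loader and iterating one epoch. The manifest path is hypothetical; each
# manifest entry must carry the 'input'/'output' shape metadata expected by
# make_batchset and LoadInputsAndTargets.
loader = BatchDataLoader(
    json_file="data/manifest.train",  # hypothetical path
    train_mode=True,
    sortagrad=True,
    batch_size=32,
    n_iter_processes=2)
for utts, xs_pad, ilens, ys_pad, olens in loader():
    pass  # feed the padded batch to the model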
@ -1,149 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from paddle.io import Dataset
from yacs.config import CfgNode

from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log

__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]

logger = Log(__name__).getlog()


class ManifestDataset(Dataset):
    @classmethod
    def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
        default = CfgNode(
            dict(
                manifest="",
                max_input_len=27.0,
                min_input_len=0.0,
                max_output_len=float('inf'),
                min_output_len=0.0,
                max_output_input_ratio=float('inf'),
                min_output_input_ratio=0.0, ))

        if config is not None:
            config.merge_from_other_cfg(default)
        return default

    @classmethod
    def from_config(cls, config):
        """Build a ManifestDataset object from a config.

        Args:
            config (yacs.config.CfgNode): configs object.

        Returns:
            ManifestDataset: dataset object.
        """
        assert 'manifest' in config.data
        assert config.data.manifest

        dataset = cls(
            manifest_path=config.data.manifest,
            max_input_len=config.data.max_input_len,
            min_input_len=config.data.min_input_len,
            max_output_len=config.data.max_output_len,
            min_output_len=config.data.min_output_len,
            max_output_input_ratio=config.data.max_output_input_ratio,
            min_output_input_ratio=config.data.min_output_input_ratio, )
        return dataset

    def __init__(self,
                 manifest_path,
                 max_input_len=float('inf'),
                 min_input_len=0.0,
                 max_output_len=float('inf'),
                 min_output_len=0.0,
                 max_output_input_ratio=float('inf'),
                 min_output_input_ratio=0.0):
        """Manifest Dataset

        Args:
            manifest_path (str): manifest json file path
            max_input_len (float, optional): maximum input seq length,
                in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
            min_input_len (float, optional): minimum input seq length,
                in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
            max_output_len (float, optional): maximum output seq length,
                in modeling units. Defaults to float('inf').
            min_output_len (float, optional): minimum output seq length,
                in modeling units. Defaults to 0.0.
            max_output_input_ratio (float, optional): maximum output/input seq length ratio.
                Defaults to float('inf').
            min_output_input_ratio (float, optional): minimum output/input seq length ratio.
                Defaults to 0.0.

        """
        super().__init__()

        # read manifest
        self._manifest = read_manifest(
            manifest_path=manifest_path,
            max_input_len=max_input_len,
            min_input_len=min_input_len,
            max_output_len=max_output_len,
            min_output_len=min_output_len,
            max_output_input_ratio=max_output_input_ratio,
            min_output_input_ratio=min_output_input_ratio)
        self._manifest.sort(key=lambda x: x["feat_shape"][0])

    def __len__(self):
        return len(self._manifest)

    def __getitem__(self, idx):
        instance = self._manifest[idx]
        return instance["utt"], instance["feat"], instance["text"]


class TripletManifestDataset(ManifestDataset):
    """
    For Joint Training of Speech Translation and ASR.
    text: translation,
    text1: transcript.
    """

    def __getitem__(self, idx):
        instance = self._manifest[idx]
        return instance["utt"], instance["feat"], instance["text"], instance[
            "text1"]


class TransformDataset(Dataset):
    """Transform Dataset.

    Args:
        data: list object from make_batchset
        converter: batch function
        reader: read data
    """

    def __init__(self, data, converter, reader):
        """Init function."""
        super().__init__()
        self.data = data
        self.converter = converter
        self.reader = reader

    def __len__(self):
        """Len function."""
        return len(self.data)

    def __getitem__(self, idx):
        """[] operator."""
        return self.converter([self.reader(self.data[idx], return_uttid=True)])
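
# --- Illustrative sketch (not part of the original diff): the shape of one
# manifest entry that ManifestDataset expects. The field values here are
# hypothetical; the keys ("utt", "feat", "feat_shape", "text") follow the
# dictionary accesses in the class above.
import json

line = json.dumps({
    "utt": "BAC009S0002W0122",      # hypothetical utterance id
    "feat": "data/feats.ark:42",    # hypothetical feature location
    "feat_shape": [498, 80],        # frames x feature dim; used for sorting
    "text": "hello world",
})
entry = json.loads(line)
print(entry["utt"], entry["feat_shape"][0])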
@ -1,410 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from collections import OrderedDict

import h5py
import kaldiio
import numpy as np
import soundfile

from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from deepspeech.utils.log import Log

__all__ = ["LoadInputsAndTargets"]

logger = Log(__name__).getlog()


class LoadInputsAndTargets():
    """Create a mini-batch from a list of dicts

    >>> batch = [('utt1',
    ...           dict(input=[dict(feat='some.ark:123',
    ...                            filetype='mat',
    ...                            name='input1',
    ...                            shape=[100, 80])],
    ...                output=[dict(tokenid='1 2 3 4',
    ...                             name='target1',
    ...                             shape=[4, 31])]))]
    >>> l = LoadInputsAndTargets()
    >>> feat, target = l(batch)

    :param: str mode: Specify the task mode, "asr" or "tts"
    :param: str preprocess_conf: The path of a json file for pre-processing
    :param: bool load_input: If False, not to load the input data
    :param: bool load_output: If False, not to load the output data
    :param: bool sort_in_input_length: Sort the mini-batch in descending order
        of the input length
    :param: dict preprocess_args: Set some optional arguments for preprocessing
    :param: bool keep_all_data_on_mem: If True, cache loaded arrays in memory
    """

    def __init__(
            self,
            mode="asr",
            preprocess_conf=None,
            load_input=True,
            load_output=True,
            sort_in_input_length=True,
            preprocess_args=None,
            keep_all_data_on_mem=False, ):
        self._loaders = {}

        if mode not in ["asr"]:
            raise ValueError("Only 'asr' mode is allowed: mode={}".format(mode))

        if preprocess_conf is not None:
            with open(preprocess_conf, 'r') as fin:
                self.preprocessing = AugmentationPipeline(fin.read())
            logger.warning(
                "[Experimental feature] Some preprocessing will be done "
                "for the mini-batch creation using {}".format(
                    self.preprocessing))
        else:
            # If conf doesn't exist, this function doesn't touch anything.
            self.preprocessing = None

        self.mode = mode
        self.load_output = load_output
        self.load_input = load_input
        self.sort_in_input_length = sort_in_input_length
        if preprocess_args is None:
            self.preprocess_args = {}
        else:
            assert isinstance(preprocess_args, dict), type(preprocess_args)
            self.preprocess_args = dict(preprocess_args)

        self.keep_all_data_on_mem = keep_all_data_on_mem

    def __call__(self, batch, return_uttid=False):
        """Function to load inputs and targets from list of dicts

        :param List[Tuple[str, dict]] batch: list of dict which is subset of
            loaded data.json
        :param bool return_uttid: return utterance ID information for visualization
        :return: list of input feature sequences
            [(T_1, D), (T_2, D), ..., (T_B, D)]
        :rtype: list of float ndarray
        :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)]
        :rtype: list of int ndarray

        """
        x_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        y_feats_dict = OrderedDict()  # OrderedDict[str, List[np.ndarray]]
        uttid_list = []  # List[str]

        for uttid, info in batch:
            uttid_list.append(uttid)

            if self.load_input:
                # Note(kamo): This for-loop is for multiple inputs
                for idx, inp in enumerate(info["input"]):
                    # {"input":
                    #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                    #    "filetype": "hdf5",
                    #    "name": "input1", ...}], ...}
                    x = self._get_from_loader(
                        filepath=inp["feat"],
                        filetype=inp.get("filetype", "mat"))
                    x_feats_dict.setdefault(inp["name"], []).append(x)

            if self.load_output:
                for idx, inp in enumerate(info["output"]):
                    if "tokenid" in inp:
                        # ======= Legacy format for output =======
                        # {"output": [{"tokenid": "1 2 3 4"}])
                        x = np.fromiter(
                            map(int, inp["tokenid"].split()), dtype=np.int64)
                    else:
                        # ======= New format =======
                        # {"output":
                        #  [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
                        #    "filetype": "hdf5",
                        #    "name": "target1", ...}], ...}
                        x = self._get_from_loader(
                            filepath=inp["feat"],
                            filetype=inp.get("filetype", "mat"))

                    y_feats_dict.setdefault(inp["name"], []).append(x)

        if self.mode == "asr":
            return_batch, uttid_list = self._create_batch_asr(
                x_feats_dict, y_feats_dict, uttid_list)
        else:
            raise NotImplementedError(self.mode)

        if self.preprocessing is not None:
            # Apply pre-processing to all input features
            for x_name in return_batch.keys():
                if x_name.startswith("input"):
                    return_batch[x_name] = self.preprocessing(
                        return_batch[x_name], uttid_list,
                        **self.preprocess_args)

        if return_uttid:
            return tuple(return_batch.values()), uttid_list

        # Doesn't return the names now.
        return tuple(return_batch.values())

    def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list):
        """Create an OrderedDict for the mini-batch

        :param OrderedDict x_feats_dict:
            e.g. {"input1": [ndarray, ndarray, ...],
                  "input2": [ndarray, ndarray, ...]}
        :param OrderedDict y_feats_dict:
            e.g. {"target1": [ndarray, ndarray, ...],
                  "target2": [ndarray, ndarray, ...]}
        :param: List[str] uttid_list:
            Give uttid_list to sort in the same order as the mini-batch
        :return: batch, uttid_list
        :rtype: Tuple[OrderedDict, List[str]]
        """
        # handle single-input and multi-input (parallel) asr mode
        xs = list(x_feats_dict.values())

        if self.load_output:
            ys = list(y_feats_dict.values())
            assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0]))

            # get index of non-zero length samples
            nonzero_idx = list(
                filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0]))))
            for n in range(1, len(y_feats_dict)):
                nonzero_idx = list(
                    filter(lambda i: len(ys[n][i]) > 0, nonzero_idx))
        else:
            # Note(kamo): Be careful not to make nonzero_idx a generator
            nonzero_idx = list(range(len(xs[0])))

        if self.sort_in_input_length:
            # sort in input lengths based on the first input
            nonzero_sorted_idx = sorted(
                nonzero_idx, key=lambda i: -len(xs[0][i]))
        else:
            nonzero_sorted_idx = nonzero_idx

        if len(nonzero_sorted_idx) != len(xs[0]):
            logger.warning(
                "Target sequences include empty tokenid (batch {} -> {}).".
                format(len(xs[0]), len(nonzero_sorted_idx)))

        # remove zero-length samples
        xs = [[x[i] for i in nonzero_sorted_idx] for x in xs]
        uttid_list = [uttid_list[i] for i in nonzero_sorted_idx]

        x_names = list(x_feats_dict.keys())
        if self.load_output:
            ys = [[y[i] for i in nonzero_sorted_idx] for y in ys]
            y_names = list(y_feats_dict.keys())

            # Keeping x_name and y_name, e.g. input1, for future extension
            return_batch = OrderedDict([
                *[(x_name, x) for x_name, x in zip(x_names, xs)],
                *[(y_name, y) for y_name, y in zip(y_names, ys)],
            ])
        else:
            return_batch = OrderedDict(
                [(x_name, x) for x_name, x in zip(x_names, xs)])
        return return_batch, uttid_list

    def _get_from_loader(self, filepath, filetype):
        """Return ndarray

        So that file descriptors are opened only on the first reference,
        the loaders are cached in self._loaders.

        >>> ndarray = loader.get_from_loader(
        ...     'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5')

        :param: str filepath:
        :param: str filetype:
        :return:
        :rtype: np.ndarray
        """
        if filetype == "hdf5":
            # e.g.
            # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
            #             "filetype": "hdf5",
            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = h5py.File(filepath, "r")
                self._loaders[filepath] = loader
            return loader[key][()]
        elif filetype == "sound.hdf5":
            # e.g.
            # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL",
            #             "filetype": "sound.hdf5",
            # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL"
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = SoundHDF5File(filepath, "r", dtype="int16")
                self._loaders[filepath] = loader
            array, rate = loader[key]
            return array
        elif filetype == "sound":
            # e.g.
            # {"input": [{"feat": "some/path.wav",
            #             "filetype": "sound"},
            # Assume PCM16
            if not self.keep_all_data_on_mem:
                array, _ = soundfile.read(filepath, dtype="int16")
                return array
            if filepath not in self._loaders:
                array, _ = soundfile.read(filepath, dtype="int16")
                self._loaders[filepath] = array
            return self._loaders[filepath]
        elif filetype == "npz":
            # e.g.
            # {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL",
            #             "filetype": "npz",
            filepath, key = filepath.split(":", 1)

            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = np.load(filepath)
                self._loaders[filepath] = loader
            return loader[key]
        elif filetype == "npy":
            # e.g.
            # {"input": [{"feat": "some/path.npy",
            #             "filetype": "npy"},
            if not self.keep_all_data_on_mem:
                return np.load(filepath)
            if filepath not in self._loaders:
                self._loaders[filepath] = np.load(filepath)
            return self._loaders[filepath]
        elif filetype in ["mat", "vec"]:
            # e.g.
            # {"input": [{"feat": "some/path.ark:123",
            #             "filetype": "mat"}]},
            # In this case, "123" indicates the starting point of the matrix
            # load_mat can load both matrix and vector
            if not self.keep_all_data_on_mem:
                return kaldiio.load_mat(filepath)
            if filepath not in self._loaders:
                self._loaders[filepath] = kaldiio.load_mat(filepath)
            return self._loaders[filepath]
        elif filetype == "scp":
            # e.g.
            # {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL",
            #             "filetype": "scp",
            filepath, key = filepath.split(":", 1)
            loader = self._loaders.get(filepath)
            if loader is None:
                # To avoid disk access, create loader only for the first time
                loader = kaldiio.load_scp(filepath)
                self._loaders[filepath] = loader
            return loader[key]
        else:
            raise NotImplementedError(
                "Not supported: loader_type={}".format(filetype))


class SoundHDF5File():
    """Collecting sound files into an HDF5 file

    >>> f = SoundHDF5File('a.flac.h5', mode='a')
    >>> array = np.random.randint(0, 100, 100, dtype=np.int16)
    >>> f['id'] = (array, 16000)
    >>> array, rate = f['id']

    :param: str filepath:
    :param: str mode:
    :param: str format: The type used when saving wav. flac, nist, htk, etc.
    :param: str dtype:

    """

    def __init__(self,
                 filepath,
                 mode="r+",
                 format=None,
                 dtype="int16",
                 **kwargs):
        self.filepath = filepath
        self.mode = mode
        self.dtype = dtype

        self.file = h5py.File(filepath, mode, **kwargs)
        if format is None:
            # filepath = a.flac.h5 -> format = flac
            second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1]
            format = second_ext[1:]
            if format.upper() not in soundfile.available_formats():
                # If not found, flac is selected
                format = "flac"

        # This format affects only saving
        self.format = format

    def __repr__(self):
        return '<SoundHDF5 file "{}" (mode {}, format {}, type {})>'.format(
            self.filepath, self.mode, self.format, self.dtype)

    def create_dataset(self, name, shape=None, data=None, **kwds):
        f = io.BytesIO()
        array, rate = data
        soundfile.write(f, array, rate, format=self.format)
        self.file.create_dataset(
            name, shape=shape, data=np.void(f.getvalue()), **kwds)

    def __setitem__(self, name, data):
        self.create_dataset(name, data=data)

    def __getitem__(self, key):
        data = self.file[key][()]
        f = io.BytesIO(data.tobytes())
        array, rate = soundfile.read(f, dtype=self.dtype)
        return array, rate

    def keys(self):
        return self.file.keys()

    def values(self):
        for k in self.file:
            yield self[k]

    def items(self):
        for k in self.file:
            yield k, self[k]

    def __iter__(self):
        return iter(self.file)

    def __contains__(self, item):
        return item in self.file

    def __len__(self):
        return len(self.file)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.file.close()

    def close(self):
        self.file.close()
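
# --- Illustrative sketch (not part of the original diff): round-tripping an
# int16 waveform through SoundHDF5File, following its own docstring example.
# The file name is hypothetical; flac encoding is lossless for PCM16 data.
import numpy as np

with SoundHDF5File('a.flac.h5', mode='a') as f:   # hypothetical file name
    array = np.random.randint(0, 100, 100, dtype=np.int16)
    f['id'] = (array, 16000)            # store (samples, sample rate)
    decoded, rate = f['id']             # decode back via soundfile
    assert rate == 16000 and len(decoded) == 100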
@ -1,251 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
from paddle import distributed as dist
from paddle.io import BatchSampler
from paddle.io import DistributedBatchSampler

from deepspeech.utils.log import Log

__all__ = [
    "SortagradDistributedBatchSampler",
    "SortagradBatchSampler",
]

logger = Log(__name__).getlog()


def _batch_shuffle(indices, batch_size, epoch, clipped=False):
    """Put similarly-sized instances into minibatches for better efficiency
    and make a batch-wise shuffle.

    1. Sort the audio clips by duration.
    2. Generate a random number `k`, k in [0, batch_size).
    3. Randomly shift `k` instances in order to create different batches
       for different epochs. Create minibatches.
    4. Shuffle the minibatches.

    :param indices: indices. List of int.
    :type indices: list
    :param batch_size: Batch size. This size is also used to generate
        a random number for batch shuffle.
    :type batch_size: int
    :param clipped: Whether to clip the heading (small shift) and trailing
        (incomplete batch) instances.
    :type clipped: bool
    :return: Batch shuffled manifest.
    :rtype: list
    """
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batch_indices)
    batch_indices = [item for batch in batch_indices for item in batch]
    assert clipped is False
    if not clipped:
        res_len = len(indices) - shift_len - len(batch_indices)
        # when res_len is 0, keep the whole list: len(List[-0:]) == len(List[:])
        if res_len != 0:
            batch_indices.extend(indices[-res_len:])
        batch_indices.extend(indices[0:shift_len])
        assert len(indices) == len(
            batch_indices
        ), f"_batch_shuffle: {len(indices)} : {len(batch_indices)} : {res_len} - {shift_len}"
    return batch_indices


class SortagradDistributedBatchSampler(DistributedBatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 num_replicas=None,
                 rank=None,
                 shuffle=False,
                 drop_last=False,
                 sortagrad=False,
                 shuffle_method="batch_shuffle"):
        """Sortagrad Sampler for multiple gpus.

        Args:
            dataset (paddle.io.Dataset):
            batch_size (int): batch size for one gpu
            num_replicas (int, optional): world size or number of gpus. Defaults to None.
            rank (int, optional): rank id. Defaults to None.
            shuffle (bool, optional): whether to shuffle. Defaults to False.
            drop_last (bool, optional): whether to drop the last batch when it is smaller than batch size. Defaults to False.
            sortagrad (bool, optional): if True, sort by length in the first epoch, then shuffle as usual. Defaults to False.
            shuffle_method (str, optional): shuffle method, "instance_shuffle" or "batch_shuffle". Defaults to "batch_shuffle".
        """
        super().__init__(dataset, batch_size, num_replicas, rank, shuffle,
                         drop_last)
        self._sortagrad = sortagrad
        self._shuffle_method = shuffle_method

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # sort (by duration) or batch-wise shuffle the manifest
        if self.shuffle:
            if self.epoch == 0 and self._sortagrad:
                logger.info(
                    f'rank: {dist.get_rank()} dataset sortagrad! epoch {self.epoch}'
                )
            else:
                logger.info(
                    f'rank: {dist.get_rank()} dataset shuffle! epoch {self.epoch}'
                )
                if self._shuffle_method == "batch_shuffle":
                    # Shuffle with `batch_size * nranks`; otherwise batches with
                    # very different example lengths land on different ranks
                    # (e.g. rank0 max length 20, rank3 max length 1000), which
                    # causes unstable loss and NaN/Inf gradients.
                    indices = _batch_shuffle(
                        indices,
                        self.batch_size * self.nranks,
                        self.epoch,
                        clipped=False)
                elif self._shuffle_method == "instance_shuffle":
                    np.random.RandomState(self.epoch).shuffle(indices)
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     self._shuffle_method)
        assert len(
            indices
        ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"

        # slice `self.batch_size` examples by rank id
        def _get_indices_by_batch_size(indices):
            subsampled_indices = []
            last_batch_size = self.total_size % (self.batch_size * self.nranks)
            assert last_batch_size % self.nranks == 0
            last_local_batch_size = last_batch_size // self.nranks

            for i in range(self.local_rank * self.batch_size,
                           len(indices) - last_batch_size,
                           self.batch_size * self.nranks):
                subsampled_indices.extend(indices[i:i + self.batch_size])

            indices = indices[len(indices) - last_batch_size:]
            subsampled_indices.extend(
                indices[self.local_rank * last_local_batch_size:(
                    self.local_rank + 1) * last_local_batch_size])
            return subsampled_indices

        if self.nranks > 1:
            indices = _get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                logger.debug(
                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size


class SortagradBatchSampler(BatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 shuffle=False,
                 drop_last=False,
                 sortagrad=False,
                 shuffle_method="batch_shuffle"):
        """Sortagrad Sampler for one gpu.

        Args:
            dataset (paddle.io.Dataset):
            batch_size (int): batch size for one gpu
            shuffle (bool, optional): whether to shuffle. Defaults to False.
            drop_last (bool, optional): whether to drop the last batch when it is smaller than batch size. Defaults to False.
            sortagrad (bool, optional): if True, sort by length in the first epoch, then shuffle as usual. Defaults to False.
            shuffle_method (str, optional): shuffle method, "instance_shuffle" or "batch_shuffle". Defaults to "batch_shuffle".
        """
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean value"

        self.drop_last = drop_last
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0))
        self.total_size = self.num_samples
        self._sortagrad = sortagrad
        self._shuffle_method = shuffle_method

    def __iter__(self):
        num_samples = len(self.dataset)
        indices = np.arange(num_samples).tolist()
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # sort (by duration) or batch-wise shuffle the manifest
        if self.shuffle:
            if self.epoch == 0 and self._sortagrad:
                logger.info(f'dataset sortagrad! epoch {self.epoch}')
            else:
                logger.info(f'dataset shuffle! epoch {self.epoch}')
                if self._shuffle_method == "batch_shuffle":
                    indices = _batch_shuffle(
                        indices, self.batch_size, self.epoch, clipped=False)
                elif self._shuffle_method == "instance_shuffle":
                    np.random.RandomState(self.epoch).shuffle(indices)
                else:
                    raise ValueError("Unknown shuffle method %s." %
                                     self._shuffle_method)
        assert len(
            indices
        ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"

        assert len(indices) == self.num_samples
        _sample_iter = iter(indices)

        batch_indices = []
        for idx in _sample_iter:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                logger.debug(
                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

        self.epoch += 1

    def __len__(self):
        num_samples = self.num_samples
        num_samples += int(not self.drop_last) * (self.batch_size - 1)
        return num_samples // self.batch_size
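
# --- Illustrative sketch (not part of the original diff): _batch_shuffle on a
# toy index list. The order changes batch-wise, but no index is lost or
# duplicated, as the function's own length assertion guarantees.
indices = list(range(10))
shuffled = _batch_shuffle(indices, batch_size=4, epoch=0, clipped=False)
assert sorted(shuffled) == indices  # a permutation of the input
print(shuffled)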
@ -1,87 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np

from deepspeech.utils.log import Log

__all__ = ["pad_list", "pad_sequence"]

logger = Log(__name__).getlog()


def pad_list(sequences: List[np.ndarray],
             padding_value: float=0.0) -> np.ndarray:
    return pad_sequence(sequences, True, padding_value)


def pad_sequence(sequences: List[np.ndarray],
                 batch_first: bool=True,
                 padding_value: float=0.0) -> np.ndarray:
    r"""Pad a list of variable length Tensors with ``padding_value``

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *``, the output is of size ``T x B x *`` if
    batch_first is False, and ``B x T x *`` otherwise.

    `B` is the batch size. It is equal to the number of elements in ``sequences``.
    `T` is the length of the longest sequence.
    `L` is the length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> a = np.ones([25, 300])
        >>> b = np.ones([22, 300])
        >>> c = np.ones([15, 300])
        >>> pad_sequence([a, b, c]).shape
        (3, 25, 300)

    Note:
        This function returns a np.ndarray of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[np.ndarray]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        np.ndarray of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        np.ndarray of size ``B x T x *`` otherwise
    """

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are the same and fetching those from sequences[0]
    max_size = sequences[0].shape
    trailing_dims = max_size[1:]
    max_len = max([s.shape[0] for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = np.full(out_dims, padding_value, dtype=sequences[0].dtype)
    for i, tensor in enumerate(sequences):
        length = tensor.shape[0]
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor
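
# --- Illustrative sketch (not part of the original diff): padding two ragged
# sequences with a sentinel value, as the collators above do with IGNORE_ID.
import numpy as np

a = np.array([1, 2, 3], dtype=np.int64)
b = np.array([4, 5], dtype=np.int64)
padded = pad_sequence([a, b], batch_first=True, padding_value=-1)
print(padded)  # [[ 1  2  3]
               #  [ 4  5 -1]]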
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,165 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
from paddle import nn
|
|
||||||
from paddle.nn import functional as F
|
|
||||||
|
|
||||||
from deepspeech.modules.activation import brelu
|
|
||||||
from deepspeech.modules.mask import make_non_pad_mask
|
|
||||||
from deepspeech.utils.log import Log
|
|
||||||
|
|
||||||
logger = Log(__name__).getlog()
|
|
||||||
|
|
||||||
__all__ = ['ConvStack', "conv_output_size"]
|
|
||||||
|
|
||||||
|
|
||||||
def conv_output_size(I, F, P, S):
|
|
||||||
# https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks#hyperparameters
|
|
||||||
# Output size after Conv:
|
|
||||||
# By noting I the length of the input volume size,
|
|
||||||
# F the length of the filter,
|
|
||||||
# P the amount of zero padding,
|
|
||||||
# S the stride,
|
|
||||||
# then the output size O of the feature map along that dimension is given by:
|
|
||||||
# O = (I - F + Pstart + Pend) // S + 1
|
|
||||||
# When Pstart == Pend == P, we can replace Pstart + Pend by 2P.
|
|
||||||
# When Pstart == Pend == 0
|
|
||||||
# O = (I - F - S) // S
|
|
||||||
# https://iq.opengenus.org/output-size-of-convolution/
|
|
||||||
# Output height = (Input height + padding height top + padding height bottom - kernel height) / (stride height) + 1
|
|
||||||
# Output width = (Output width + padding width right + padding width left - kernel width) / (stride width) + 1
|
|
||||||
return (I - F + 2 * P - S) // S
|
|
||||||
|
|
||||||
|
|
||||||
class ConvBn(nn.Layer):
|
|
||||||
"""Convolution layer with batch normalization.
|
|
||||||
|
|
||||||
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
|
|
||||||
two image dimension.
|
|
    :type kernel_size: int|tuple|list
    :param num_channels_in: Number of input channels.
    :type num_channels_in: int
    :param num_channels_out: Number of output channels.
    :type num_channels_out: int
    :param stride: The x dimension of the stride. Or input a tuple for two
                   image dimensions.
    :type stride: int|tuple|list
    :param padding: The x dimension of the padding. Or input a tuple for two
                    image dimensions.
    :type padding: int|tuple|list
    :param act: Activation type, relu|brelu
    :type act: string
    :return: Batch norm layer after convolution layer.
    :rtype: Variable

    """

    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
                 padding, act):

        super().__init__()
        assert len(kernel_size) == 2
        assert len(stride) == 2
        assert len(padding) == 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.conv = nn.Conv2D(
            num_channels_in,
            num_channels_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            weight_attr=None,
            bias_attr=False,
            data_format='NCHW')

        self.bn = nn.BatchNorm2D(
            num_channels_out,
            weight_attr=None,
            bias_attr=None,
            data_format='NCHW')
        self.act = F.relu if act == 'relu' else brelu

    def forward(self, x, x_len):
        """
        x (Tensor): audio, shape [B, C, D, T]
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        # output time length after striding over the time axis
        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
                 ) // self.stride[1] + 1

        # reset the padding part to 0
        masks = make_non_pad_mask(x_len)  # [B, T]
        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
        # TODO(Hui Zhang): bool multiply is not supported, cast mask to x.dtype
        # masks = masks.type_as(x)
        masks = masks.astype(x.dtype)
        x = x.multiply(masks)

        return x, x_len


class ConvStack(nn.Layer):
    """Convolution group with stacked convolution layers.

    :param feat_size: audio feature dim.
    :type feat_size: int
    :param num_stacks: Number of stacked convolution layers.
    :type num_stacks: int
    """

    def __init__(self, feat_size, num_stacks):
        super().__init__()
        self.feat_size = feat_size  # D
        self.num_stacks = num_stacks

        self.conv_in = ConvBn(
            num_channels_in=1,
            num_channels_out=32,
            kernel_size=(41, 11),  # [D, T]
            stride=(2, 3),
            padding=(20, 5),
            act='brelu')

        out_channel = 32
        convs = [
            ConvBn(
                num_channels_in=32,
                num_channels_out=out_channel,
                kernel_size=(21, 11),
                stride=(2, 1),
                padding=(10, 5),
                act='brelu') for i in range(num_stacks - 1)
        ]
        self.conv_stack = nn.LayerList(convs)

        # conv output feat_dim
        output_height = (feat_size - 1) // 2 + 1
        for i in range(self.num_stacks - 1):
            output_height = (output_height - 1) // 2 + 1
        self.output_height = out_channel * output_height

    def forward(self, x, x_len):
        """
        x: shape [B, C, D, T]
        x_len: shape [B]
        """
        x, x_len = self.conv_in(x, x_len)
        for i, conv in enumerate(self.conv_stack):
            x, x_len = conv(x, x_len)
        return x, x_len

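# --- Editor's sketch (not part of the original file): the helper below is
# hypothetical and just mirrors the output-length formula used in
# ConvBn.forward, with the conv_in hyper-parameters above on the time axis
# (kernel 11, stride 3, padding 5).
def conv_out_len(x_len, kernel=11, stride=3, padding=5):
    # standard convolution output-length formula
    return (x_len - kernel + 2 * padding) // stride + 1

assert conv_out_len(100) == 34  # a 100-frame utterance yields 34 frames
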
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,145 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict

import paddle
from paddle import nn
from paddle.nn import functional as F

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["get_activation", "brelu", "LinearGLUBlock", "ConvGLUBlock"]


def brelu(x, t_min=0.0, t_max=24.0, name=None):
    # paddle.to_tensor is dygraph-only and can not work under JIT,
    # so build the clipping bounds with paddle.full instead.
    t_min = paddle.full(shape=[1], fill_value=t_min, dtype='float32')
    t_max = paddle.full(shape=[1], fill_value=t_max, dtype='float32')
    return x.maximum(t_min).minimum(t_max)


class GLU(nn.Layer):
    """Gated Linear Unit over a given axis (used by ConvGLUBlock below):
    splits the axis into halves a and b and returns a * sigmoid(b)."""

    def __init__(self, axis: int=-1):
        super().__init__()
        self.axis = axis

    def forward(self, xs):
        return F.glu(xs, axis=self.axis)


class LinearGLUBlock(nn.Layer):
    """A linear Gated Linear Units (GLU) block."""

    def __init__(self, idim: int):
        """GLU.
        Args:
            idim (int): input and output dimension
        """
        super().__init__()
        self.fc = nn.Linear(idim, idim * 2)

    def forward(self, xs):
        return F.glu(self.fc(xs), axis=-1)


class ConvGLUBlock(nn.Layer):
    def __init__(self, kernel_size, in_ch, out_ch, bottleneck_dim=0,
                 dropout=0.):
        """A convolutional Gated Linear Units (GLU) block.

        Args:
            kernel_size (int): kernel size
            in_ch (int): number of input channels
            out_ch (int): number of output channels
            bottleneck_dim (int): dimension of the bottleneck layers for
                computational efficiency. Defaults to 0 (no bottleneck).
            dropout (float): dropout probability. Defaults to 0.0.
        """

        super().__init__()

        self.conv_residual = None
        if in_ch != out_ch:
            self.conv_residual = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch, out_channels=out_ch, kernel_size=(1, 1)),
                name='weight',
                dim=0)
            self.dropout_residual = nn.Dropout(p=dropout)

        self.pad_left = nn.Pad2D((0, 0, kernel_size - 1, 0), value=0.)

        layers = OrderedDict()
        if bottleneck_dim == 0:
            layers['conv'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch,
                    out_channels=out_ch * 2,
                    kernel_size=(kernel_size, 1)),
                name='weight',
                dim=0)
            # TODO(hirofumi0810): padding?
            layers['dropout'] = nn.Dropout(p=dropout)
            layers['glu'] = GLU(axis=1)

        elif bottleneck_dim > 0:
            layers['conv_in'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=in_ch,
                    out_channels=bottleneck_dim,
                    kernel_size=(1, 1)),
                name='weight',
                dim=0)
            layers['dropout_in'] = nn.Dropout(p=dropout)
            layers['conv_bottleneck'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=bottleneck_dim,
                    out_channels=bottleneck_dim,
                    kernel_size=(kernel_size, 1)),
                name='weight',
                dim=0)
            layers['dropout'] = nn.Dropout(p=dropout)
            layers['glu'] = GLU(axis=1)
            layers['conv_out'] = nn.utils.weight_norm(
                nn.Conv2D(
                    in_channels=bottleneck_dim,
                    out_channels=out_ch * 2,
                    kernel_size=(1, 1)),
                name='weight',
                dim=0)
            layers['dropout_out'] = nn.Dropout(p=dropout)

        self.layers = nn.Sequential(*layers.items())

    def forward(self, xs):
        """Forward pass.
        Args:
            xs (FloatTensor): `[B, in_ch, T, feat_dim]`
        Returns:
            out (FloatTensor): `[B, out_ch, T, feat_dim]`
        """
        residual = xs
        if self.conv_residual is not None:
            residual = self.dropout_residual(self.conv_residual(residual))
        xs = self.pad_left(xs)  # `[B, embed_dim, T+kernel-1, 1]`
        xs = self.layers(xs)  # `[B, out_ch * 2, T, 1]`
        xs = xs + residual
        return xs


def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    activation_funcs = {
        "hardtanh": paddle.nn.Hardtanh,
        "tanh": paddle.nn.Tanh,
        "relu": paddle.nn.ReLU,
        "selu": paddle.nn.SELU,
        "swish": paddle.nn.Swish,
        "gelu": paddle.nn.GELU,
        "brelu": brelu,
    }

    return activation_funcs[act]()

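# --- Editor's sketch (not part of the original file): brelu clips values into
# [t_min, t_max], while get_activation builds a layer by name.
import paddle

x = paddle.to_tensor([-3.0, 5.0, 30.0])
print(brelu(x).numpy())       # [ 0.  5. 24.]
relu = get_activation("relu")
print(relu(x).numpy())        # [ 0.  5. 30.]
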
@ -1,51 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['GlobalCMVN']


class GlobalCMVN(nn.Layer):
    def __init__(self,
                 mean: paddle.Tensor,
                 istd: paddle.Tensor,
                 norm_var: bool=True):
        """
        Args:
            mean (paddle.Tensor): mean stats
            istd (paddle.Tensor): inverse std, i.e. 1.0 / std
        """
        super().__init__()
        assert mean.shape == istd.shape
        self.norm_var = norm_var
        # The buffer can be accessed from this module using self.mean
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: paddle.Tensor):
        """
        Args:
            x (paddle.Tensor): (batch, max_len, feat_dim)
        Returns:
            (paddle.Tensor): normalized feature
        """
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x

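# --- Editor's sketch (not part of the original file): applying GlobalCMVN
# with made-up stats; real mean/istd come from training-set statistics.
import paddle

feat_dim = 4
mean = paddle.full([feat_dim], 2.0)
istd = paddle.full([feat_dim], 0.5)        # i.e. std == 2.0
cmvn = GlobalCMVN(mean, istd, norm_var=True)
feats = paddle.randn([8, 100, feat_dim])   # (batch, max_len, feat_dim)
normed = cmvn(feats)                       # computes (feats - 2.0) * 0.5
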
@ -1,370 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['CRF']


class CRF(nn.Layer):
    """
    Linear-chain Conditional Random Field (CRF).

    Args:
        nb_labels (int): number of labels in your tagset, including special symbols.
        bos_tag_id (int): integer representing the beginning of sentence symbol in
            your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
        pad_tag_id (int, optional): integer representing the pad symbol in your tagset.
            If None, the model will treat the PAD as a normal tag. Otherwise, the model
            will apply constraints for PAD transitions.
        batch_first (bool): Whether the first dimension represents the batch dimension.
    """

    def __init__(self,
                 nb_labels: int,
                 bos_tag_id: int,
                 eos_tag_id: int,
                 pad_tag_id: int=None,
                 batch_first: bool=True):
        super().__init__()

        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        self.PAD_TAG_ID = pad_tag_id
        self.batch_first = batch_first

        # initialize transitions from a random uniform distribution between -0.1 and 0.1
        self.transitions = self.create_parameter(
            [self.nb_labels, self.nb_labels],
            default_initializer=nn.initializer.Uniform(-0.1, 0.1))
        self.init_weights()

    def init_weights(self):
        # enforce constraints (rows=from, columns=to) with a big negative number
        # so exp(-10000) will tend to zero

        # no transitions allowed to the beginning of sentence
        self.transitions[:, self.BOS_TAG_ID] = -10000.0
        # no transitions allowed from the end of sentence
        self.transitions[self.EOS_TAG_ID, :] = -10000.0

        if self.PAD_TAG_ID is not None:
            # no transitions from padding
            self.transitions[self.PAD_TAG_ID, :] = -10000.0
            # no transitions to padding
            self.transitions[:, self.PAD_TAG_ID] = -10000.0
            # except if the end of sentence is reached
            # or we are already in a pad position
            self.transitions[self.PAD_TAG_ID, self.EOS_TAG_ID] = 0.0
            self.transitions[self.PAD_TAG_ID, self.PAD_TAG_ID] = 0.0

    def forward(self,
                emissions: paddle.Tensor,
                tags: paddle.Tensor,
                mask: paddle.Tensor=None) -> paddle.Tensor:
        """Compute the negative log-likelihood. See `log_likelihood` method."""
        nll = -self.log_likelihood(emissions, tags, mask=mask)
        return nll

    def log_likelihood(self, emissions, tags, mask=None):
        """Compute the probability of a sequence of tags given a sequence of
        emission scores.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            tags (paddle.LongTensor): Sequence of labels.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape of (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: sum of the log-likelihoods for each sequence in the batch.
                Scalar, shape of ().
        """
        # fix tensor order by setting batch as the first dimension
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float32)

        scores = self._compute_scores(emissions, tags, mask=mask)
        partition = self._compute_log_partition(emissions, mask=mask)
        return paddle.sum(scores - partition)

    def decode(self, emissions, mask=None):
        """Find the most probable sequence of labels given the emissions using
        the Viterbi algorithm.

        Args:
            emissions (paddle.Tensor): Sequence of emissions for each label.
                Shape (batch_size, seq_len, nb_labels) if batch_first is True,
                (seq_len, batch_size, nb_labels) otherwise.
            mask (paddle.FloatTensor, optional): Tensor representing valid positions.
                If None, all positions are considered valid.
                Shape (batch_size, seq_len) if batch_first is True,
                (seq_len, batch_size) otherwise.

        Returns:
            paddle.Tensor: the Viterbi score for each batch.
                Shape of (batch_size,)
            list of lists: the best Viterbi sequence of labels for each batch. [B, T]
        """
        # fix tensor order by setting batch as the first dimension
        if not self.batch_first:
            emissions = emissions.transpose(0, 1)

        if mask is None:
            mask = paddle.ones(emissions.shape[:2], dtype=paddle.float32)

        scores, sequences = self._viterbi_decode(emissions, mask)
        return scores, sequences

    def _compute_scores(self, emissions, tags, mask):
        """Compute the scores for a given batch of emissions with their tags.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            tags (paddle.LongTensor): (batch_size, seq_len)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: Scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length = tags.shape
        scores = paddle.zeros([batch_size])

        # save first and last tags to be used later
        first_tags = tags[:, 0]
        last_valid_idx = mask.int().sum(1) - 1

        # TODO(Hui Zhang): fancy indexing is not supported, use gather_nd instead.
        # last_tags = tags.gather(last_valid_idx.unsqueeze(1), axis=1).squeeze()
        batch_idx = paddle.arange(batch_size, dtype=last_valid_idx.dtype)
        gather_last_valid_idx = paddle.stack(
            [batch_idx, last_valid_idx], axis=-1)
        last_tags = tags.gather_nd(gather_last_valid_idx)

        # add the transition from BOS to the first tags for each batch
        # t_scores = self.transitions[self.BOS_TAG_ID, first_tags]
        t_scores = self.transitions[self.BOS_TAG_ID].gather(first_tags)

        # add the [unary] emission scores for the first tags for each batch
        # for all batches, the first word, see the correspondent emissions
        # for the first tags (which is a list of ids):
        # emissions[:, 0, [tag_1, tag_2, ..., tag_nblabels]]
        # e_scores = emissions[:, 0].gather(1, first_tags.unsqueeze(1)).squeeze()
        gather_first_tags_idx = paddle.stack([batch_idx, first_tags], axis=-1)
        e_scores = emissions[:, 0].gather_nd(gather_first_tags_idx)

        # the score for a word is just the sum of both scores
        scores += e_scores + t_scores

        # now let's do this for each remaining word
        for i in range(1, seq_length):

            # we could iterate over batches, check if we reached a mask symbol
            # and stop the iteration, but vectorizing is faster due to gpu,
            # so instead we perform an element-wise multiplication
            is_valid = mask[:, i]

            previous_tags = tags[:, i - 1]
            current_tags = tags[:, i]

            # calculate emission and transition scores as we did before
            # e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
            gather_current_tags_idx = paddle.stack(
                [batch_idx, current_tags], axis=-1)
            e_scores = emissions[:, i].gather_nd(gather_current_tags_idx)
            # t_scores = self.transitions[previous_tags, current_tags]
            gather_transitions_idx = paddle.stack(
                [previous_tags, current_tags], axis=-1)
            t_scores = self.transitions.gather_nd(gather_transitions_idx)

            # apply the mask
            e_scores = e_scores * is_valid
            t_scores = t_scores * is_valid

            scores += e_scores + t_scores

        # add the transition from the last tag to the EOS tag for each batch
        # scores += self.transitions[last_tags, self.EOS_TAG_ID]
        scores += self.transitions.gather(last_tags)[:, self.EOS_TAG_ID]

        return scores

    def _compute_log_partition(self, emissions, mask):
        """Compute the partition function in log-space using the forward algorithm.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the partition scores for each batch.
                Shape of (batch_size,)
        """
        batch_size, seq_length, nb_labels = emissions.shape

        # in the first iteration, BOS will have all the scores
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            scores = e_scores + t_scores + a_scores
            new_alphas = paddle.logsumexp(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * new_alphas + (1 - is_valid) * alphas

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # return a *log* of sums of exps
        return paddle.logsumexp(end_scores, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Compute the Viterbi algorithm to find the most probable sequence of labels
        given a sequence of emissions.

        Args:
            emissions (paddle.Tensor): (batch_size, seq_len, nb_labels)
            mask (paddle.FloatTensor): (batch_size, seq_len)

        Returns:
            paddle.Tensor: the Viterbi score for each batch.
                Shape of (batch_size,)
            list of lists of ints: the best Viterbi sequence of labels for each batch
        """
        batch_size, seq_length, nb_labels = emissions.shape

        # in the first iteration, BOS will have all the scores and then, the max
        alphas = self.transitions[self.BOS_TAG_ID, :].unsqueeze(
            0) + emissions[:, 0]

        backpointers = []

        for i in range(1, seq_length):
            # (bs, nb_labels) -> (bs, 1, nb_labels)
            e_scores = emissions[:, i].unsqueeze(1)

            # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
            t_scores = self.transitions.unsqueeze(0)

            # (bs, nb_labels) -> (bs, nb_labels, 1)
            a_scores = alphas.unsqueeze(2)

            # combine current scores with previous alphas
            scores = e_scores + t_scores + a_scores

            # so far this is exactly like the forward algorithm,
            # but now, instead of calculating the logsumexp,
            # we will find the highest score and the tag associated with it
            # max_scores, max_score_tags = paddle.max(scores, axis=1)
            max_scores = paddle.max(scores, axis=1)
            max_score_tags = paddle.argmax(scores, axis=1)

            # set alphas if the mask is valid, otherwise keep the current values
            is_valid = mask[:, i].unsqueeze(-1)
            alphas = is_valid * max_scores + (1 - is_valid) * alphas

            # add the max_score_tags to our list of backpointers
            # max_scores has shape (batch_size, nb_labels) so we transpose it to
            # be compatible with our previous loopy version of viterbi
            backpointers.append(max_score_tags.t())

        # add the scores for the final transition
        last_transition = self.transitions[:, self.EOS_TAG_ID]
        end_scores = alphas + last_transition.unsqueeze(0)

        # get the final most probable score and the final most probable tag
        # max_final_scores, max_final_tags = paddle.max(end_scores, axis=1)
        max_final_scores = paddle.max(end_scores, axis=1)
        max_final_tags = paddle.argmax(end_scores, axis=1)

        # find the best sequence of labels for each sample in the batch
        best_sequences = []
        emission_lengths = mask.int().sum(axis=1)
        for i in range(batch_size):

            # recover the original sentence length for the i-th sample in the batch
            sample_length = emission_lengths[i].item()

            # recover the max tag for the last timestep
            sample_final_tag = max_final_tags[i].item()

            # limit the backpointers until the last but one
            # since the last corresponds to the sample_final_tag
            sample_backpointers = backpointers[:sample_length - 1]

            # follow the backpointers to build the sequence of labels
            sample_path = self._find_best_path(i, sample_final_tag,
                                               sample_backpointers)

            # add this path to the list of best sequences
            best_sequences.append(sample_path)

        return max_final_scores, best_sequences

    def _find_best_path(self, sample_id, best_tag, backpointers):
        """Auxiliary function to find the best path sequence for a specific sample.

        Args:
            sample_id (int): sample index in the range [0, batch_size)
            best_tag (int): tag which maximizes the final score
            backpointers (list of lists of tensors): list of pointers with
                shape (seq_len_i-1, nb_labels, batch_size) where seq_len_i
                represents the length of the ith sample in the batch

        Returns:
            list of ints: a list of tag indexes representing the best path
        """
        # add the final best_tag to our best path
        best_path = [best_tag]

        # traverse the backpointers backwards
        for backpointers_t in reversed(backpointers):

            # recover the best_tag at this timestep
            best_tag = backpointers_t[best_tag][sample_id].item()

            # append to the beginning of the list so we don't need to reverse it later
            best_path.insert(0, best_tag)

        return best_path

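# --- Editor's sketch (not part of the original file): training loss and
# Viterbi decoding with the CRF above. The tag ids below are arbitrary choices.
import paddle

crf = CRF(nb_labels=5, bos_tag_id=0, eos_tag_id=1, pad_tag_id=2)
emissions = paddle.randn([2, 7, 5])               # (batch, seq_len, nb_labels)
tags = paddle.randint(3, 5, shape=[2, 7])         # gold tags, batch_first
mask = paddle.ones([2, 7], dtype=paddle.float32)
nll = crf(emissions, tags, mask=mask)             # negative log-likelihood
scores, paths = crf.decode(emissions, mask=mask)  # Viterbi scores and paths
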
@ -1,274 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn import functional as F
from typeguard import check_argument_types

from deepspeech.modules.loss import CTCLoss
from deepspeech.utils import ctc_utils
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

try:
    from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch  # noqa: F401
    from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder  # noqa: F401
    from deepspeech.decoders.swig_wrapper import Scorer  # noqa: F401
except Exception as e:
    logger.info("ctcdecoder not installed!")

__all__ = ['CTCDecoder']


class CTCDecoder(nn.Layer):
    def __init__(self,
                 odim,
                 enc_n_units,
                 blank_id=0,
                 dropout_rate: float=0.0,
                 reduction: bool=True,
                 batch_average: bool=True):
        """CTC decoder

        Args:
            odim (int): text vocabulary size
            enc_n_units (int): encoder output dimension
            blank_id (int): blank label id. Defaults to 0.
            dropout_rate (float): dropout rate (0.0 ~ 1.0)
            reduction (bool): reduce the CTC loss into a scalar; True uses 'sum',
                False uses 'none'
            batch_average (bool): do batch dim wise average.
        """
        assert check_argument_types()
        super().__init__()

        self.blank_id = blank_id
        self.odim = odim
        self.dropout_rate = dropout_rate
        self.ctc_lo = nn.Linear(enc_n_units, self.odim)
        reduction_type = "sum" if reduction else "none"
        self.criterion = CTCLoss(
            blank=self.blank_id,
            reduction=reduction_type,
            batch_average=batch_average)

        # CTCDecoder LM Score handle
        self._ext_scorer = None

    def forward(self, hs_pad, hlens, ys_pad, ys_lens):
        """Calculate CTC loss.

        Args:
            hs_pad (Tensor): batch of padded hidden state sequences (B, Tmax, D)
            hlens (Tensor): batch of lengths of hidden state sequences (B)
            ys_pad (Tensor): batch of padded character id sequence tensor (B, Lmax)
            ys_lens (Tensor): batch of lengths of character sequence (B)
        Returns:
            loss (Tensor): ctc loss value, scalar.
        """
        logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        loss = self.criterion(logits, ys_pad, hlens, ys_lens)
        return loss

    def softmax(self, eouts: paddle.Tensor, temperature: float=1.0):
        """Get CTC probabilities.
        Args:
            eouts (FloatTensor): `[B, T, enc_units]`
        Returns:
            probs (FloatTensor): `[B, T, odim]`
        """
        self.probs = F.softmax(self.ctc_lo(eouts) / temperature, axis=2)
        return self.probs

    def log_softmax(self, hs_pad: paddle.Tensor,
                    temperature: float=1.0) -> paddle.Tensor:
        """log_softmax of frame activations
        Args:
            hs_pad (Tensor): 3d tensor (B, Tmax, eprojs)
        Returns:
            paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad) / temperature, axis=2)

    def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
        """argmax of frame activations
        Args:
            hs_pad (paddle.Tensor): 3d tensor (B, Tmax, eprojs)
        Returns:
            paddle.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return paddle.argmax(self.ctc_lo(hs_pad), axis=2)

    def forced_align(self,
                     ctc_probs: paddle.Tensor,
                     y: paddle.Tensor,
                     blank_id=0) -> list:
        """ctc forced alignment.
        Args:
            ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
            y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
            blank_id (int): blank symbol index
        Returns:
            paddle.Tensor: best alignment result, (T).
        """
        return ctc_utils.forced_align(ctc_probs, y, blank_id)

    def _decode_batch_greedy(self, probs_split, vocab_list):
        """Decode by best path for a batch of probs matrix input.
        :param probs_split: List of 2-D probability matrices, each consisting
                            of prob vectors for one speech utterance.
        :type probs_split: list
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        :return: List of transcription texts.
        :rtype: List of str
        """
        results = []
        for i, probs in enumerate(probs_split):
            output_transcription = ctc_greedy_decoder(
                probs_seq=probs, vocabulary=vocab_list, blank_id=self.blank_id)
            results.append(output_transcription)
        return results

    def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
                         vocab_list):
        """Initialize the external scorer.
        :param beam_alpha: Parameter associated with language model.
        :type beam_alpha: float
        :param beam_beta: Parameter associated with word count.
        :type beam_beta: float
        :param language_model_path: Filepath for language model. If it is
                                    empty, the external scorer will be set to
                                    None, and the decoding method will be pure
                                    beam search without scorer.
        :type language_model_path: str|None
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        """
        # init once
        if self._ext_scorer is not None:
            return

        if language_model_path != '':
            logger.info("begin to initialize the external scorer "
                        "for decoding")
            self._ext_scorer = Scorer(beam_alpha, beam_beta,
                                      language_model_path, vocab_list)
            lm_char_based = self._ext_scorer.is_character_based()
            lm_max_order = self._ext_scorer.get_max_order()
            lm_dict_size = self._ext_scorer.get_dict_size()
            logger.info("language model: "
                        "is_character_based = %d," % lm_char_based +
                        " max_order = %d," % lm_max_order + " dict_size = %d" %
                        lm_dict_size)
            logger.info("end initializing scorer")
        else:
            self._ext_scorer = None
            logger.info("no language model provided, "
                        "decoding by pure beam search without scorer.")

    def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
                                  beam_size, cutoff_prob, cutoff_top_n,
                                  vocab_list, num_processes):
        """Decode by beam search for a batch of probs matrix input.
        :param probs_split: List of 2-D probability matrices, each consisting
                            of prob vectors for one speech utterance.
        :type probs_split: list
        :param beam_alpha: Parameter associated with language model.
        :type beam_alpha: float
        :param beam_beta: Parameter associated with word count.
        :type beam_beta: float
        :param beam_size: Width for beam search.
        :type beam_size: int
        :param cutoff_prob: Cutoff probability in pruning,
                            default 1.0, no pruning.
        :type cutoff_prob: float
        :param cutoff_top_n: Cutoff number in pruning; only the top cutoff_top_n
                             characters with the highest probs in the vocabulary
                             will be used in beam search, default 40.
        :type cutoff_top_n: int
        :param vocab_list: List of tokens in the vocabulary, for decoding.
        :type vocab_list: list
        :param num_processes: Number of processes (CPU) for decoder.
        :type num_processes: int
        :return: List of transcription texts.
        :rtype: List of str
        """
        if self._ext_scorer is not None:
            self._ext_scorer.reset_params(beam_alpha, beam_beta)

        # beam search decode
        num_processes = min(num_processes, len(probs_split))
        beam_search_results = ctc_beam_search_decoder_batch(
            probs_split=probs_split,
            vocabulary=vocab_list,
            beam_size=beam_size,
            num_processes=num_processes,
            ext_scoring_func=self._ext_scorer,
            cutoff_prob=cutoff_prob,
            cutoff_top_n=cutoff_top_n,
            blank_id=self.blank_id)

        results = [result[0][1] for result in beam_search_results]
        return results

    def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
                    decoding_method):

        if decoding_method == "ctc_beam_search":
            self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
                                  vocab_list)

    def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
                     lang_model_path, beam_alpha, beam_beta, beam_size,
                     cutoff_prob, cutoff_top_n, num_processes):
        """ctc decoding with probs.

        Args:
            probs (Tensor): activation after softmax
            logits_lens (Tensor): audio output lens
            vocab_list (list): list of tokens in the vocabulary, for decoding
            decoding_method (str): "ctc_greedy" or "ctc_beam_search"
            lang_model_path (str): filepath for language model
            beam_alpha (float): parameter associated with language model
            beam_beta (float): parameter associated with word count
            beam_size (int): width for beam search
            cutoff_prob (float): cutoff probability in pruning
            cutoff_top_n (int): cutoff number in pruning
            num_processes (int): number of processes (CPU) for decoder

        Raises:
            ValueError: when decoding_method is not supported.

        Returns:
            List[str]: transcripts.
        """

        probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
        if decoding_method == "ctc_greedy":
            result_transcripts = self._decode_batch_greedy(
                probs_split=probs_split, vocab_list=vocab_list)
        elif decoding_method == "ctc_beam_search":
            result_transcripts = self._decode_batch_beam_search(
                probs_split=probs_split,
                beam_alpha=beam_alpha,
                beam_beta=beam_beta,
                beam_size=beam_size,
                cutoff_prob=cutoff_prob,
                cutoff_top_n=cutoff_top_n,
                vocab_list=vocab_list,
                num_processes=num_processes)
        else:
            raise ValueError(
                f"Decoding method not supported: {decoding_method}")
        return result_transcripts

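# --- Editor's sketch (not part of the original file): greedy decoding with
# CTCDecoder; this needs the ctcdecoder extension noted in the try/except
# above, and the 30-token vocabulary below is made up.
import paddle

vocab_list = [chr(ord('a') + i) for i in range(26)] + [' ', "'", '-', '.']
decoder = CTCDecoder(odim=len(vocab_list), enc_n_units=256, blank_id=0)
eouts = paddle.randn([2, 50, 256])        # (B, T, enc_units)
probs = decoder.softmax(eouts)            # (B, T, odim)
texts = decoder.decode_probs(
    probs.numpy(), logits_lens=[50, 42], vocab_list=vocab_list,
    decoding_method="ctc_greedy", lang_model_path="", beam_alpha=1.9,
    beam_beta=0.3, beam_size=10, cutoff_prob=1.0, cutoff_top_n=40,
    num_processes=1)
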
@ -1,182 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder definition."""
from typing import List
from typing import Optional
from typing import Tuple

import paddle
from paddle import nn
from typeguard import check_argument_types

from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.decoder_layer import DecoderLayer
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.mask import make_non_pad_mask
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["TransformerDecoder"]


class TransformerDecoder(nn.Layer):
    """Base class of Transformer decoder module.
    Args:
        vocab_size: output dim
        encoder_output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of hidden units of the position-wise feedforward
        num_blocks: the number of decoder blocks
        dropout_rate: dropout rate
        self_attention_dropout_rate: dropout rate for attention
        input_layer: input layer type, `embed`
        use_output_layer: whether to use output layer
        normalize_before:
            True: use layer_norm before each sub-block of a layer.
            False: use layer_norm after each sub-block of a layer.
        concat_after: whether to concat attention layer's input and output
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """

    def __init__(
            self,
            vocab_size: int,
            encoder_output_size: int,
            attention_heads: int=4,
            linear_units: int=2048,
            num_blocks: int=6,
            dropout_rate: float=0.1,
            positional_dropout_rate: float=0.1,
            self_attention_dropout_rate: float=0.0,
            src_attention_dropout_rate: float=0.0,
            input_layer: str="embed",
            use_output_layer: bool=True,
            normalize_before: bool=True,
            concat_after: bool=False, ):

        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        if input_layer == "embed":
            self.embed = nn.Sequential(
                nn.Embedding(vocab_size, attention_dim),
                PositionalEncoding(attention_dim, positional_dropout_rate), )
        else:
            raise ValueError(f"only 'embed' is supported: {input_layer}")

        self.normalize_before = normalize_before
        self.after_norm = nn.LayerNorm(attention_dim, epsilon=1e-12)
        self.use_output_layer = use_output_layer
        self.output_layer = nn.Linear(attention_dim, vocab_size)

        self.decoders = nn.LayerList([
            DecoderLayer(
                size=attention_dim,
                self_attn=MultiHeadedAttention(attention_heads, attention_dim,
                                               self_attention_dropout_rate),
                src_attn=MultiHeadedAttention(attention_heads, attention_dim,
                                              src_attention_dropout_rate),
                feed_forward=PositionwiseFeedForward(
                    attention_dim, linear_units, dropout_rate),
                dropout_rate=dropout_rate,
                normalize_before=normalize_before,
                concat_after=concat_after, ) for _ in range(num_blocks)
        ])

    def forward(
            self,
            memory: paddle.Tensor,
            memory_mask: paddle.Tensor,
            ys_in_pad: paddle.Tensor,
            ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Forward decoder.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoder memory mask, (batch, 1, maxlen_in)
            ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
            ys_in_lens: input lengths of this batch (batch)
        Returns:
            (tuple): tuple containing:
                x: decoded token scores before softmax (batch, maxlen_out, vocab_size)
                    if use_output_layer is True,
                olens: (batch, )
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (make_non_pad_mask(ys_in_lens).unsqueeze(1))
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
        # tgt_mask: (B, L, L)
        # TODO(Hui Zhang): & is not supported for bool tensors, use logical_and
        # tgt_mask = tgt_mask & m
        tgt_mask = tgt_mask.logical_and(m)

        x, _ = self.embed(tgt)
        for layer in self.decoders:
            x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
                                                     memory_mask)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.use_output_layer:
            x = self.output_layer(x)

        # TODO(Hui Zhang): reduce_sum does not support bool type
        # olens = tgt_mask.sum(1)
        olens = tgt_mask.astype(paddle.int).sum(1)
        return x, olens

    def forward_one_step(
            self,
            memory: paddle.Tensor,
            memory_mask: paddle.Tensor,
            tgt: paddle.Tensor,
            tgt_mask: paddle.Tensor,
            cache: Optional[List[paddle.Tensor]]=None,
    ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]:
        """Forward one step.
        This is only used for decoding.
        Args:
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            memory_mask: encoded memory mask, (batch, 1, maxlen_in)
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
                dtype=paddle.bool
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, token)
        """
        x, _ = self.embed(tgt)
        new_cache = []
        for i, decoder in enumerate(self.decoders):
            if cache is None:
                c = None
            else:
                c = cache[i]
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask, cache=c)
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.use_output_layer:
            y = paddle.log_softmax(self.output_layer(y), axis=-1)
        return y, new_cache

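# --- Editor's sketch (not part of the original file): a teacher-forcing call
# into TransformerDecoder; shapes follow the forward() docstring above and the
# values are random.
import paddle

decoder = TransformerDecoder(vocab_size=100, encoder_output_size=256)
memory = paddle.randn([2, 30, 256])                  # encoder output
memory_mask = paddle.ones([2, 1, 30], dtype='bool')  # all frames valid
ys_in_pad = paddle.randint(0, 100, shape=[2, 12])    # <sos>-prefixed targets
ys_in_lens = paddle.to_tensor([12, 9])
logits, olens = decoder(memory, memory_mask, ys_in_pad, ys_in_lens)
# logits: (2, 12, 100) token scores before softmax, per the docstring
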
@ -1,151 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional
from typing import Tuple

import paddle
from paddle import nn

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["DecoderLayer"]


class DecoderLayer(nn.Layer):
    """Single decoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (nn.Layer): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (nn.Layer): Source-attention module instance, attending over
            the encoder output.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (nn.Layer): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input
            and output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """

    def __init__(
            self,
            size: int,
            self_attn: nn.Layer,
            src_attn: nn.Layer,
            feed_forward: nn.Layer,
            dropout_rate: float,
            normalize_before: bool=True,
            concat_after: bool=False, ):
        """Construct a DecoderLayer object."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, epsilon=1e-12)
        self.norm2 = nn.LayerNorm(size, epsilon=1e-12)
        self.norm3 = nn.LayerNorm(size, epsilon=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.concat_linear1 = nn.Linear(size + size, size)
        self.concat_linear2 = nn.Linear(size + size, size)

    def forward(
            self,
            tgt: paddle.Tensor,
            tgt_mask: paddle.Tensor,
            memory: paddle.Tensor,
            memory_mask: paddle.Tensor,
            cache: Optional[paddle.Tensor]=None
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute decoded features.
        Args:
            tgt (paddle.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (paddle.Tensor): Mask for input tensor
                (#batch, maxlen_out).
            memory (paddle.Tensor): Encoded memory
                (#batch, maxlen_in, size).
            memory_mask (paddle.Tensor): Encoded memory mask
                (#batch, maxlen_in).
            cache (paddle.Tensor): cached tensors.
                (#batch, maxlen_out - 1, size).
        Returns:
            paddle.Tensor: Output tensor (#batch, maxlen_out, size).
            paddle.Tensor: Mask for output tensor (#batch, maxlen_out).
            paddle.Tensor: Encoded memory (#batch, maxlen_in, size).
            paddle.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame's query, keeping dim: max_time_out -> 1
            assert cache.shape == [
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ], f"{cache.shape} == {[tgt.shape[0], tgt.shape[1] - 1, self.size]}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            # TODO(Hui Zhang): slice does not support bool type
            # tgt_q_mask = tgt_mask[:, -1:, :]
            tgt_q_mask = tgt_mask.cast(paddle.int64)[:, -1:, :].cast(
                paddle.bool)

        if self.concat_after:
            tgt_concat = paddle.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(
                self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = paddle.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(
                self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = paddle.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask

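# --- Editor's sketch (not part of the original file): wiring one DecoderLayer
# from the building blocks it expects, using the module paths imported by
# decoder.py above.
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward

size, heads = 256, 4
layer = DecoderLayer(
    size=size,
    self_attn=MultiHeadedAttention(heads, size, 0.0),
    src_attn=MultiHeadedAttention(heads, size, 0.0),
    feed_forward=PositionwiseFeedForward(size, 2048, 0.1),
    dropout_rate=0.1,
    normalize_before=True)  # pre-norm: norm -> sublayer -> dropout -> residual
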
@ -1,453 +0,0 @@
|
|||||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""Encoder definition."""
|
|
||||||
from typing import List
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
import paddle
|
|
||||||
from paddle import nn
|
|
||||||
from typeguard import check_argument_types
|
|
||||||
|
|
||||||
from deepspeech.modules.activation import get_activation
|
|
||||||
from deepspeech.modules.attention import MultiHeadedAttention
|
|
||||||
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
|
|
||||||
from deepspeech.modules.conformer_convolution import ConvolutionModule
|
|
||||||
from deepspeech.modules.embedding import PositionalEncoding
|
|
||||||
from deepspeech.modules.embedding import RelPositionalEncoding
|
|
||||||
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
|
|
||||||
from deepspeech.modules.encoder_layer import TransformerEncoderLayer
|
|
||||||
from deepspeech.modules.mask import add_optional_chunk_mask
|
|
||||||
from deepspeech.modules.mask import make_non_pad_mask
|
|
||||||
from deepspeech.modules.positionwise_feed_forward import PositionwiseFeedForward
|
|
||||||
from deepspeech.modules.subsampling import Conv2dSubsampling4
|
|
||||||
from deepspeech.modules.subsampling import Conv2dSubsampling6
|
|
||||||
from deepspeech.modules.subsampling import Conv2dSubsampling8
|
|
||||||
from deepspeech.modules.subsampling import LinearNoSubsampling
|
|
||||||
from deepspeech.utils.log import Log
|
|
||||||
|
|
||||||
logger = Log(__name__).getlog()

__all__ = ["BaseEncoder", 'TransformerEncoder', "ConformerEncoder"]


class BaseEncoder(nn.Layer):
    def __init__(
            self,
            input_size: int,
            output_size: int=256,
            attention_heads: int=4,
            linear_units: int=2048,
            num_blocks: int=6,
            dropout_rate: float=0.1,
            positional_dropout_rate: float=0.1,
            attention_dropout_rate: float=0.0,
            input_layer: str="conv2d",
            pos_enc_layer_type: str="abs_pos",
            normalize_before: bool=True,
            concat_after: bool=False,
            static_chunk_size: int=0,
            use_dynamic_chunk: bool=False,
            global_cmvn: paddle.nn.Layer=None,
            use_dynamic_left_chunk: bool=False, ):
        """
        Args:
            input_size (int): input dim, d_feature
            output_size (int): dimension of attention, d_model
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, scaled_abs_pos, rel_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            concat_after (bool): whether to concat attention layer's input
                and output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use a fixed chunk
                (chunk_size > 0) or a dynamic chunk size
                (use_dynamic_chunk = True).
            global_cmvn (Optional[paddle.nn.Layer]): Optional GlobalCMVN layer
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            idim=input_size,
            odim=output_size,
            dropout_rate=dropout_rate,
            pos_enc_class=pos_enc_class(
                d_model=output_size, dropout_rate=positional_dropout_rate), )

        self.normalize_before = normalize_before
        self.after_norm = nn.LayerNorm(output_size, epsilon=1e-12)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

    def output_size(self) -> int:
        return self._output_size

    def forward(
            self,
            xs: paddle.Tensor,
            xs_lens: paddle.Tensor,
            decoding_chunk_size: int=0,
            num_decoding_left_chunks: int=-1,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Embed positions in tensor.
        Args:
            xs: padded input tensor (B, L, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks, this is for decoding,
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor and output mask
        """
        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, L)

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
        xs, pos_emb, masks = self.embed(xs, masks.type_as(xs), offset=0)
        #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
        masks = masks.astype(paddle.bool)
        #TODO(Hui Zhang): mask_pad = ~masks
        mask_pad = masks.logical_not()
        chunk_masks = add_optional_chunk_mask(
            xs, masks, self.use_dynamic_chunk, self.use_dynamic_left_chunk,
            decoding_chunk_size, self.static_chunk_size,
            num_decoding_left_chunks)
        for layer in self.encoders:
            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return xs, masks

    def forward_chunk(
            self,
            xs: paddle.Tensor,
            offset: int,
            required_cache_size: int,
            subsampling_cache: Optional[paddle.Tensor]=None,
            elayers_output_cache: Optional[List[paddle.Tensor]]=None,
            conformer_cnn_cache: Optional[List[paddle.Tensor]]=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, List[paddle.Tensor], List[
            paddle.Tensor]]:
        """ Forward just one chunk
        Args:
            xs (paddle.Tensor): chunk input, [B=1, T, D]
            offset (int): current offset in encoder output time stamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            subsampling_cache (Optional[paddle.Tensor]): subsampling cache
            elayers_output_cache (Optional[List[paddle.Tensor]]):
                transformer/conformer encoder layers output cache
            conformer_cnn_cache (Optional[List[paddle.Tensor]]): conformer
                cnn cache
        Returns:
            paddle.Tensor: output of current input xs
            paddle.Tensor: subsampling cache required for next chunk computation
            List[paddle.Tensor]: encoder layers output cache required for next
                chunk computation
            List[paddle.Tensor]: conformer cnn cache
        """
        assert xs.size(0) == 1  # batch size must be one
        # tmp_masks is just for interface compatibility
        # TODO(Hui Zhang): stride_slice not support bool tensor
        # tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
        tmp_masks = paddle.ones([1, xs.size(1)], dtype=paddle.int32)
        tmp_masks = tmp_masks.unsqueeze(1)  #[B=1, C=1, T]

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        xs, pos_emb, _ = self.embed(
            xs, tmp_masks, offset=offset)  #xs=(B, T, D), pos_emb=(B=1, T, D)

        if subsampling_cache is not None:
            cache_size = subsampling_cache.size(1)  #T
            xs = paddle.cat((subsampling_cache, xs), dim=1)
        else:
            cache_size = 0

        # only used when using `RelPositionMultiHeadedAttention`
        pos_emb = self.embed.position_encoding(
            offset=offset - cache_size, size=xs.size(1))

        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = xs.size(1)
        else:
            next_cache_start = xs.size(1) - required_cache_size
        r_subsampling_cache = xs[:, next_cache_start:, :]

        # Real mask for transformer/conformer layers
        masks = paddle.ones([1, xs.size(1)], dtype=paddle.bool)
        masks = masks.unsqueeze(1)  #[B=1, L'=1, T]
        r_elayers_output_cache = []
        r_conformer_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            attn_cache = None if elayers_output_cache is None else elayers_output_cache[
                i]
            cnn_cache = None if conformer_cnn_cache is None else conformer_cnn_cache[
                i]
            xs, _, new_cnn_cache = layer(
                xs,
                masks,
                pos_emb,
                output_cache=attn_cache,
                cnn_cache=cnn_cache)
            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
            r_conformer_cnn_cache.append(new_cnn_cache)
        if self.normalize_before:
            xs = self.after_norm(xs)

        return (xs[:, cache_size:, :], r_subsampling_cache,
                r_elayers_output_cache, r_conformer_cnn_cache)

    def forward_chunk_by_chunk(
            self,
            xs: paddle.Tensor,
            decoding_chunk_size: int,
            num_decoding_left_chunks: int=-1,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """ Forward input chunk by chunk with chunk_size like a streaming
            fashion.
        Here we should pay special attention to the computation cache in the
        streaming style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling
        However, we don't implement subsampling cache because:
            1. We can make the subsampling module output the right result by
               overlapping input instead of caching left context, even though
               it wastes some computation; subsampling only takes a very
               small fraction of the computation in the whole model.
            2. Typically, there are several convolution layers with
               subsampling in the subsampling module, so it is tricky and
               complicated to cache across convolution layers with different
               subsampling rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with cache, which is not preferred.
        Args:
            xs (paddle.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size.
            num_decoding_left_chunks (int): decoding with num left chunks.
        """
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk

        # feature stride and window for `subsampling` module
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context

        num_frames = xs.size(1)
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
        subsampling_cache: Optional[paddle.Tensor] = None
        elayers_output_cache: Optional[List[paddle.Tensor]] = None
        conformer_cnn_cache: Optional[List[paddle.Tensor]] = None
        outputs = []
        offset = 0
        # Feed forward overlap input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, subsampling_cache, elayers_output_cache,
             conformer_cnn_cache) = self.forward_chunk(
                 chunk_xs, offset, required_cache_size, subsampling_cache,
                 elayers_output_cache, conformer_cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = paddle.cat(outputs, 1)
        # fake mask, just for jit script and compatibility with `forward` api
        masks = paddle.ones([1, ys.size(1)], dtype=paddle.bool)
        masks = masks.unsqueeze(1)
        return ys, masks
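
    # A streaming sketch (illustrative; the concrete numbers assume a conv2d
    # subsampling of rate 4 with right_context 6): with decoding_chunk_size=16,
    # each step consumes stride = 4 * 16 = 64 new frames through a window of
    # (16 - 1) * 4 + 6 + 1 = 67 frames, so consecutive windows overlap by
    # context - subsampling frames instead of caching inside `embed`.
    # >>> ys, masks = encoder.forward_chunk_by_chunk(feats, 16)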


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""

    def __init__(
            self,
            input_size: int,
            output_size: int=256,
            attention_heads: int=4,
            linear_units: int=2048,
            num_blocks: int=6,
            dropout_rate: float=0.1,
            positional_dropout_rate: float=0.1,
            attention_dropout_rate: float=0.0,
            input_layer: str="conv2d",
            pos_enc_layer_type: str="abs_pos",
            normalize_before: bool=True,
            concat_after: bool=False,
            static_chunk_size: int=0,
            use_dynamic_chunk: bool=False,
            global_cmvn: nn.Layer=None,
            use_dynamic_left_chunk: bool=False, ):
        """ Construct TransformerEncoder
        See Encoder for the meaning of each parameter.
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads, linear_units,
                         num_blocks, dropout_rate, positional_dropout_rate,
                         attention_dropout_rate, input_layer,
                         pos_enc_layer_type, normalize_before, concat_after,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk)
        self.encoders = nn.LayerList([
            TransformerEncoderLayer(
                size=output_size,
                self_attn=MultiHeadedAttention(attention_heads, output_size,
                                               attention_dropout_rate),
                feed_forward=PositionwiseFeedForward(output_size, linear_units,
                                                     dropout_rate),
                dropout_rate=dropout_rate,
                normalize_before=normalize_before,
                concat_after=concat_after) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""

    def __init__(
            self,
            input_size: int,
            output_size: int=256,
            attention_heads: int=4,
            linear_units: int=2048,
            num_blocks: int=6,
            dropout_rate: float=0.1,
            positional_dropout_rate: float=0.1,
            attention_dropout_rate: float=0.0,
            input_layer: str="conv2d",
            pos_enc_layer_type: str="rel_pos",
            normalize_before: bool=True,
            concat_after: bool=False,
            static_chunk_size: int=0,
            use_dynamic_chunk: bool=False,
            global_cmvn: nn.Layer=None,
            use_dynamic_left_chunk: bool=False,
            positionwise_conv_kernel_size: int=1,
            macaron_style: bool=True,
            selfattention_layer_type: str="rel_selfattn",
            activation_type: str="swish",
            use_cnn_module: bool=True,
            cnn_module_kernel: int=15,
            causal: bool=False,
            cnn_module_norm: str="batch_norm", ):
        """Construct ConformerEncoder
        Args:
            input_size to use_dynamic_chunk, see BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now; it is kept only for
                configuration compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            cnn_module_norm (str): cnn conv norm type, Optional['batch_norm','layer_norm']
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads, linear_units,
                         num_blocks, dropout_rate, positional_dropout_rate,
                         attention_dropout_rate, input_layer,
                         pos_enc_layer_type, normalize_before, concat_after,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
                         use_dynamic_left_chunk)
        activation = get_activation(activation_type)

        # self-attention module definition
        encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (attention_heads, output_size,
                                       attention_dropout_rate)
        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (output_size, linear_units, dropout_rate,
                                   activation)
        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = nn.LayerList([
            ConformerEncoderLayer(
                size=output_size,
                self_attn=encoder_selfattn_layer(*encoder_selfattn_layer_args),
                feed_forward=positionwise_layer(*positionwise_layer_args),
                feed_forward_macaron=positionwise_layer(
                    *positionwise_layer_args) if macaron_style else None,
                conv_module=convolution_layer(*convolution_layer_args)
                if use_cnn_module else None,
                dropout_rate=dropout_rate,
                normalize_before=normalize_before,
                concat_after=concat_after) for _ in range(num_blocks)
        ])
@ -1,144 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
from paddle.nn import functional as F

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ['CTCLoss', "LabelSmoothingLoss"]


class CTCLoss(nn.Layer):
    def __init__(self, blank=0, reduction='sum', batch_average=False):
        super().__init__()
        # last token id as blank id
        self.loss = nn.CTCLoss(blank=blank, reduction=reduction)
        self.batch_average = batch_average

    def forward(self, logits, ys_pad, hlens, ys_lens):
        """Compute CTC loss.

        Args:
            logits ([paddle.Tensor]): [B, Tmax, D]
            ys_pad ([paddle.Tensor]): [B, Tmax]
            hlens ([paddle.Tensor]): [B]
            ys_lens ([paddle.Tensor]): [B]

        Returns:
            [paddle.Tensor]: scalar. If reduction is 'none', then (N), where
                N = batch size.
        """
        B = paddle.shape(logits)[0]
        # warp-ctc needs logits and does softmax on them itself
        # warp-ctc needs activations with shape [T, B, V + 1]
        # logits: (B, L, D) -> (L, B, D)
        logits = logits.transpose([1, 0, 2])
        # (TODO: Hui Zhang) ctc loss does not support int64 labels
        ys_pad = ys_pad.astype(paddle.int32)
        loss = self.loss(
            logits, ys_pad, hlens, ys_lens, norm_by_times=self.batch_average)
        if self.batch_average:
            # Batch-size average
            loss = loss / B
        return loss
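
    # Shape sketch (illustrative values): logits (B, Tmax, D) from the model,
    # ys_pad (B, Tmax) padded label ids, hlens/ys_lens the true lengths.
    # >>> ctc = CTCLoss(blank=0, reduction='sum', batch_average=True)
    # >>> loss = ctc(logits, ys_pad, hlens, ys_lens)  # summed, then / B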


class LabelSmoothingLoss(nn.Layer):
    """Label-smoothing loss.
    In a standard CE loss, the label's data distribution is:
        [0,1,2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
    In the label-smoothed CE loss, some probability mass is taken from the
    true label's probability (1.0) and distributed among the other labels.
    e.g.
    smoothing=0.1
        [0,1,2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]

    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool=False):
        """Label-smoothing loss.

        Args:
            size (int): the number of classes
            padding_idx (int): padding class id which will be ignored for loss
            smoothing (float): smoothing rate (0.0 means the conventional CE)
            normalize_length (bool):
                True, normalize loss by sequence length;
                False, normalize loss by batch size.
                Defaults to False.
        """
        super().__init__()
        self.size = size
        self.padding_idx = padding_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.normalize_length = normalize_length
        self.criterion = nn.KLDivLoss(reduction="none")

    def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor:
        """Compute loss between x and target.
        The model output and data label tensors are flattened to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be calculated for loss.

        Args:
            x (paddle.Tensor): prediction (batch, seqlen, class)
            target (paddle.Tensor):
                target signal masked with self.padding_idx (batch, seqlen)
        Returns:
            loss (paddle.Tensor) : The KL loss, scalar float value
        """
        B, T, D = paddle.shape(x)
        assert D == self.size
        x = x.reshape((-1, self.size))
        target = target.reshape([-1])

        # build true_dist with full_like instead of under torch.no_grad(),
        # since no_grad() can not be exported by JIT
        true_dist = paddle.full_like(x, self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)

        # target = target * (1 - ignore)  # avoid -1 index
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        # true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        target_mask = F.one_hot(target, self.size)
        true_dist *= (1 - target_mask)
        true_dist += target_mask * self.confidence

        kl = self.criterion(F.log_softmax(x, axis=1), true_dist)

        #TODO(Hui Zhang): sum not support bool type
        #total = len(target) - int(ignore.sum())
        total = len(target) - int(ignore.type_as(target).sum())
        denom = total if self.normalize_length else B
        #numer = (kl * (1 - ignore)).sum()
        numer = kl.masked_fill(ignore.unsqueeze(1), 0).sum()
        return numer / denom
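
    # Worked example (illustrative): with size=3, padding_idx=-1 and
    # smoothing=0.1, target id 1 becomes the soft row [0.05, 0.9, 0.05]:
    # full_like fills 0.1 / (3 - 1) = 0.05 everywhere, then the one-hot
    # position is overwritten with confidence 1 - 0.1 = 0.9, matching the
    # docstring table above.
    # >>> loss_fn = LabelSmoothingLoss(size=3, padding_idx=-1, smoothing=0.1)
    # >>> loss = loss_fn(logits, target)  # logits: (B, T, 3), target: (B, T)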
@ -1,260 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = [
    "make_pad_mask", "make_non_pad_mask", "subsequent_mask",
    "subsequent_chunk_mask", "add_optional_chunk_mask", "mask_finished_scores",
    "mask_finished_preds"
]


def make_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Make mask tensor containing indices of padded part.
    See description of make_non_pad_mask.
    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: Mask tensor containing indices of padded part.
    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    # (TODO: Hui Zhang): jit not support Tensor.dim() and Tensor.ndim
    # assert lengths.dim() == 1
    batch_size = int(lengths.shape[0])
    max_len = int(lengths.max())
    seq_range = paddle.arange(0, max_len, dtype=paddle.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand([batch_size, max_len])
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Make mask tensor containing indices of non-padded part.
    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To prevent the padding part from passing values into
    context-dependent blocks such as attention or convolution, the
    padding part is masked.
    This pad_mask is used in both encoder and decoder.
    1 for non-padded part and 0 for padded part.
    Args:
        lengths (paddle.Tensor): Batch of lengths (B,).
    Returns:
        paddle.Tensor: mask tensor containing indices of non-padded part.
    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    #TODO(Hui Zhang): return ~make_pad_mask(lengths), not support ~
    return make_pad_mask(lengths).logical_not()


def subsequent_mask(size: int) -> paddle.Tensor:
    """Create mask for subsequent steps (size, size).
    This mask is used only in the decoder, which works in an auto-regressive
    mode. This means the current step could only do attention with its left
    steps.
    In the encoder, full attention is used when streaming is not necessary
    and the sequence is not long. In this case, no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.
    Args:
        size (int): size of mask
    Returns:
        paddle.Tensor: mask, [size, size]
    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = paddle.ones([size, size], dtype=paddle.bool)
    #TODO(Hui Zhang): tril not support bool
    #return paddle.tril(ret)
    ret = ret.astype(paddle.float)
    ret = paddle.tril(ret)
    ret = ret.astype(paddle.bool)
    return ret


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int=-1, ) -> paddle.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder
    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
    Returns:
        paddle.Tensor: mask, [size, size]
    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = paddle.zeros([size, size], dtype=paddle.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max(0, (i // chunk_size - num_left_chunks) * chunk_size)
        ending = min(size, (i // chunk_size + 1) * chunk_size)
        ret[i, start:ending] = True
    return ret
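
# Left-context sketch (illustrative): with a bounded history of one chunk,
# >>> subsequent_chunk_mask(6, 2, num_left_chunks=1)
# row i=4 covers columns 2..5 only, since start = (4 // 2 - 1) * 2 = 2 and
# ending = (4 // 2 + 1) * 2 = 6; older frames are masked out for streaming.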


def add_optional_chunk_mask(xs: paddle.Tensor,
                            masks: paddle.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.
    Args:
        xs (paddle.Tensor): padded input, (B, L, D), L for max length
        masks (paddle.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding,
            used when it is greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored.
        num_decoding_left_chunks (int): number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
    Returns:
        paddle.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.shape[1]
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = int(paddle.randint(1, max_len, (1, )))
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = int(
                        paddle.randint(0, max_left_chunks, (1, )))
        chunk_masks = subsequent_chunk_mask(xs.shape[1], chunk_size,
                                            num_left_chunks)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        # chunk_masks = masks & chunk_masks  # (B, L, L)
        chunk_masks = masks.logical_and(chunk_masks)  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.shape[1], static_chunk_size,
                                            num_left_chunks)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        # chunk_masks = masks & chunk_masks  # (B, L, L)
        chunk_masks = masks.logical_and(chunk_masks)  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
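
# Training-time sampling sketch (illustrative): with decoding_chunk_size == 0
# and max_len = 80, a draw of chunk_size = 53 exceeds 80 // 2 and becomes full
# context (chunk_size = 80), while a draw of 30 becomes 30 % 25 + 1 = 6, so a
# dynamic-chunk model sees both streaming-sized and full-context masks.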


def mask_finished_scores(score: paddle.Tensor,
                         flag: paddle.Tensor) -> paddle.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.
    Args:
        score (paddle.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (paddle.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
    Returns:
        paddle.Tensor: (batch_size * beam_size, beam_size).
    Examples:
        flag: tensor([[ True],
                      [False]])
        score: tensor([[-0.3666, -0.6664,  0.6019],
                       [-1.1490, -0.2948,  0.7460]])
        unfinished: tensor([[False,  True,  True],
                            [False, False, False]])
        finished: tensor([[ True, False, False],
                          [False, False, False]])
        return: tensor([[ 0.0000,    -inf,    -inf],
                        [-1.1490, -0.2948,  0.7460]])
    """
    beam_size = score.shape[-1]
    zero_mask = paddle.zeros_like(flag, dtype=paddle.bool)
    if beam_size > 1:
        unfinished = paddle.concat(
            (zero_mask, flag.tile([1, beam_size - 1])), axis=1)
        finished = paddle.concat(
            (flag, zero_mask.tile([1, beam_size - 1])), axis=1)
    else:
        unfinished = zero_mask
        finished = flag

    # infs = paddle.ones_like(score) * -float('inf')
    # score = paddle.where(unfinished, infs, score)
    # score = paddle.where(finished, paddle.zeros_like(score), score)
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score


def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
                        eos: int) -> paddle.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.
    Args:
        pred (paddle.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (paddle.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
    Returns:
        paddle.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.shape[-1]
    # use tile, mirroring mask_finished_scores above; paddle has no
    # torch-style Tensor.repeat
    finished = flag.tile([1, beam_size])
    return pred.masked_fill_(finished, eos)
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,54 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate the BLEU score at different
levels, e.g. bleu for word level, char_bleu for char level.
"""
import sacrebleu

__all__ = ['bleu', 'char_bleu']


def bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text at word level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference length is zero.
    """

    return sacrebleu.corpus_bleu(hypothesis, reference)


def char_bleu(hypothesis, reference):
    """Calculate BLEU. BLEU compares reference text and
    hypothesis text at char level using sacrebleu.

    :param reference: The reference sentences.
    :type reference: list[list[str]]
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: list[str]
    :raises ValueError: If the reference number is zero.
    """
    hypothesis = [' '.join(list(hyp.replace(' ', ''))) for hyp in hypothesis]
    reference = [[' '.join(list(ref_i.replace(' ', ''))) for ref_i in ref]
                 for ref in reference]

    return sacrebleu.corpus_bleu(hypothesis, reference)
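
# Preprocessing sketch (illustrative): char_bleu first strips spaces, then
# rejoins every character with a single space, so sacrebleu's word tokenizer
# sees characters as tokens, e.g. "今天 天气" -> "今 天 天 气".
# >>> char_bleu(["今天天气"], [["今天 天气"]])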
@ -1,298 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import re
from pathlib import Path
from typing import Mapping  # needed by the best_records annotation below
from typing import Text
from typing import Union

import paddle
from paddle import distributed as dist
from paddle.optimizer import Optimizer

from deepspeech.utils import mp_tools
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["Checkpoint"]


class Checkpoint():
    def __init__(self, kbest_n: int=5, latest_n: int=1):
        self.best_records: Mapping[Path, float] = {}
        self.latest_records = []
        self.kbest_n = kbest_n
        self.latest_n = latest_n
        self._save_all = (kbest_n == -1)

    def add_checkpoint(self,
                       checkpoint_dir,
                       tag_or_iteration: Union[int, Text],
                       model: paddle.nn.Layer,
                       optimizer: Optimizer=None,
                       infos: dict=None,
                       metric_type="val_loss"):
        """Save checkpoint in best_n and latest_n.

        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag.
            model (Layer): model to be checkpointed.
            optimizer (Optimizer, optional): optimizer to be checkpointed.
            infos (dict or None): any info you want to save.
            metric_type (str, optional): metric type. Defaults to "val_loss".
        """
        # guard against infos=None, which has no keys to look up
        if infos is None or metric_type not in infos:
            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                                  optimizer, infos)
            return

        #save best
        if self._should_save_best(infos[metric_type]):
            self._save_best_checkpoint_and_update(
                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
                optimizer, infos)
        #save latest
        self._save_latest_checkpoint_and_update(
            checkpoint_dir, tag_or_iteration, model, optimizer, infos)

        if isinstance(tag_or_iteration, int):
            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)

    def load_parameters(self,
                        model,
                        optimizer=None,
                        checkpoint_dir=None,
                        checkpoint_path=None,
                        record_file="checkpoint_latest"):
        """Load a model checkpoint from disk.
        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
            record_file (str): "checkpoint_latest" or "checkpoint_best".
        Returns:
            configs (dict): epoch or step, lr and other saved meta info.
        """
        configs = {}

        if checkpoint_path is not None:
            pass
        elif checkpoint_dir is not None and record_file is not None:
            # load checkpoint from record file
            checkpoint_record = os.path.join(checkpoint_dir, record_file)
            iteration = self._load_checkpoint_idx(checkpoint_record)
            if iteration == -1:
                return configs
            checkpoint_path = os.path.join(checkpoint_dir,
                                           "{}".format(iteration))
        else:
            raise ValueError(
                "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
            )

        rank = dist.get_rank()

        params_path = checkpoint_path + ".pdparams"
        model_dict = paddle.load(params_path)
        model.set_state_dict(model_dict)
        logger.info("Rank {}: loaded model from {}".format(rank, params_path))

        optimizer_path = checkpoint_path + ".pdopt"
        if optimizer and os.path.isfile(optimizer_path):
            optimizer_dict = paddle.load(optimizer_path)
            optimizer.set_state_dict(optimizer_dict)
            logger.info("Rank {}: loaded optimizer state from {}".format(
                rank, optimizer_path))

        info_path = re.sub('.pdparams$', '.json', params_path)
        if os.path.exists(info_path):
            with open(info_path, 'r') as fin:
                configs = json.load(fin)
        return configs

    def load_latest_parameters(self,
                               model,
                               optimizer=None,
                               checkpoint_dir=None,
                               checkpoint_path=None):
        """Load the latest model checkpoint from disk.
        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
        Returns:
            configs (dict): epoch or step, lr and other saved meta info.
        """
        return self.load_parameters(model, optimizer, checkpoint_dir,
                                    checkpoint_path, "checkpoint_latest")

    def load_best_parameters(self,
                             model,
                             optimizer=None,
                             checkpoint_dir=None,
                             checkpoint_path=None):
        """Load the best model checkpoint from disk.
        Args:
            model (Layer): model to load parameters.
            optimizer (Optimizer, optional): optimizer to load states if needed.
                Defaults to None.
            checkpoint_dir (str, optional): the directory where checkpoint is saved.
            checkpoint_path (str, optional): if specified, load the checkpoint
                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
                be ignored. Defaults to None.
        Returns:
            configs (dict): epoch or step, lr and other saved meta info.
        """
        return self.load_parameters(model, optimizer, checkpoint_dir,
                                    checkpoint_path, "checkpoint_best")

    def _should_save_best(self, metric: float) -> bool:
        if not self._best_full():
            return True

        # already full
        worst_record_path = max(self.best_records, key=self.best_records.get)
        # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
        worst_metric = self.best_records[worst_record_path]
        return metric < worst_metric

    def _best_full(self):
        return (not self._save_all) and len(self.best_records) == self.kbest_n

    def _latest_full(self):
        return len(self.latest_records) == self.latest_n

    def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
                                         tag_or_iteration, model, optimizer,
                                         infos):
        # remove the worst
        if self._best_full():
            worst_record_path = max(self.best_records,
                                    key=self.best_records.get)
            self.best_records.pop(worst_record_path)
            if (worst_record_path not in self.latest_records):
                logger.info(
                    "remove the worst checkpoint: {}".format(worst_record_path))
                self._del_checkpoint(checkpoint_dir, worst_record_path)

        # add the new one
        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)
        self.best_records[tag_or_iteration] = metric

    def _save_latest_checkpoint_and_update(
            self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
        # remove the old
        if self._latest_full():
            to_del_fn = self.latest_records.pop(0)
            if (to_del_fn not in self.best_records.keys()):
                logger.info(
                    "remove the latest checkpoint: {}".format(to_del_fn))
                self._del_checkpoint(checkpoint_dir, to_del_fn)
        self.latest_records.append(tag_or_iteration)

        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
                              optimizer, infos)

    def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "{}".format(tag_or_iteration))
        for filename in glob.glob(checkpoint_path + ".*"):
            os.remove(filename)
            logger.info("delete file: {}".format(filename))

    def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
        """Get the iteration number corresponding to the latest saved checkpoint.
        Args:
            checkpoint_record (str): path of the checkpoint record file.
        Returns:
            int: the latest iteration number. -1 for no checkpoint to load.
        """
        if not os.path.isfile(checkpoint_record):
            return -1

        # Fetch the latest checkpoint index.
        with open(checkpoint_record, "rt") as handle:
            latest_checkpoint = handle.readlines()[-1].strip()
            iteration = int(latest_checkpoint.split(":")[-1])
        return iteration

    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
        """Save the best and latest checkpoint records.
        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            iteration (int): the latest iteration number.
        Returns:
            None
        """
        checkpoint_record_latest = os.path.join(checkpoint_dir,
                                                "checkpoint_latest")
        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")

        with open(checkpoint_record_best, "w") as handle:
            for i in self.best_records.keys():
                handle.write("model_checkpoint_path:{}\n".format(i))
        with open(checkpoint_record_latest, "w") as handle:
            for i in self.latest_records:
                handle.write("model_checkpoint_path:{}\n".format(i))

    @mp_tools.rank_zero_only
    def _save_parameters(self,
                         checkpoint_dir: str,
                         tag_or_iteration: Union[int, str],
                         model: paddle.nn.Layer,
                         optimizer: Optimizer=None,
                         infos: dict=None):
        """Checkpoint the latest trained model parameters.
        Args:
            checkpoint_dir (str): the directory where checkpoint is saved.
            tag_or_iteration (int or str): the latest iteration(step or epoch) number.
            model (Layer): model to be checkpointed.
            optimizer (Optimizer, optional): optimizer to be checkpointed.
                Defaults to None.
            infos (dict or None): any info you want to save.
        Returns:
            None
        """
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "{}".format(tag_or_iteration))

        model_dict = model.state_dict()
        params_path = checkpoint_path + ".pdparams"
        paddle.save(model_dict, params_path)
        logger.info("Saved model to {}".format(params_path))

        if optimizer:
            opt_dict = optimizer.state_dict()
            optimizer_path = checkpoint_path + ".pdopt"
            paddle.save(opt_dict, optimizer_path)
            logger.info("Saved optimizer state to {}".format(optimizer_path))

        info_path = re.sub('.pdparams$', '.json', params_path)
        infos = {} if infos is None else infos
        with open(info_path, 'w') as fout:
            data = json.dumps(infos)
            fout.write(data)
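
# Typical driver usage (illustrative sketch; variable names are assumptions):
# >>> ckpt = Checkpoint(kbest_n=5, latest_n=1)
# >>> ckpt.add_checkpoint(ckpt_dir, epoch, model, optimizer,
# ...                     infos={'val_loss': float(valid_loss)})
# The kbest_n records with the smallest val_loss survive in the
# `checkpoint_best` record file; `checkpoint_latest` tracks the newest
# latest_n records.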
@ -1,134 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import numpy as np
import paddle

from deepspeech.utils.log import Log

logger = Log(__name__).getlog()

__all__ = ["forced_align", "remove_duplicates_and_blank", "insert_blank"]


def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
    """ctc alignment to ctc label ids.

    "abaa-acee-" -> "abaace"

    Args:
        hyp (List[int]): hypotheses ids, (L)
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[int]: ids with repeats collapsed and blank ids removed.
    """
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        # add non-blank into new_hyp
        if hyp[cur] != blank_id:
            new_hyp.append(hyp[cur])
        # skip repeat label
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp
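
# Walk-through (illustrative), encoding the docstring's "abaa-acee-" with
# a=1, b=2, c=3, e=4 and '-' = blank_id = 0:
# >>> remove_duplicates_and_blank([1, 2, 1, 1, 0, 1, 3, 4, 4, 0])
# [1, 2, 1, 1, 3, 4]
# Runs collapse and blanks drop, while the blank keeps the a's of "aa-a"
# from merging into one.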


def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
    """Insert blank token between every two label tokens.

    "abcdefg" -> "-a-b-c-d-e-f-g-"

    Args:
        label ([np.ndarray]): label ids, List[int], (L).
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        [np.ndarray]: (2L+1).
    """
    label = np.expand_dims(label, 1)  #[L, 1]
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)  #[L, 2]
    label = label.reshape(-1)  #[2L], -l-l-l
    label = np.append(label, label[0])  #[2L + 1], -l-l-l-
    return label


def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                 blank_id=0) -> List[int]:
    """ctc forced alignment.

    https://distill.pub/2017/ctc/

    Args:
        ctc_probs (paddle.Tensor): hidden state sequence, 2d tensor (T, D)
        y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        List[int]: best alignment result, (T).
    """
    y_insert_blank = insert_blank(y, blank_id)  #(2L+1)

    log_alpha = paddle.zeros(
        (ctc_probs.size(0), len(y_insert_blank)))  #(T, 2L+1)
    log_alpha = log_alpha - float('inf')  # log of zero
    # TODO(Hui Zhang): zeros not support paddle.int16
    state_path = (paddle.zeros(
        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
                  )  # state path, Tuple((T, 2L+1))

    # init start state
    # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
    log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])]  # State-b, Sb
    log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])]  # State-nb, Snb

    for t in range(1, ctc_probs.size(0)):  # T
        for s in range(len(y_insert_blank)):  # 2L+1
            if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                    s] == y_insert_blank[s - 2]:
                candidates = paddle.to_tensor(
                    [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
                prev_state = [s, s - 1]
            else:
                candidates = paddle.to_tensor([
                    log_alpha[t - 1, s],
                    log_alpha[t - 1, s - 1],
                    log_alpha[t - 1, s - 2],
                ])
                prev_state = [s, s - 1, s - 2]
            # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
                y_insert_blank[s])]
            state_path[t, s] = prev_state[paddle.argmax(candidates)]

    # TODO(Hui Zhang): zeros not support paddle.int16
    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)

    candidates = paddle.to_tensor([
        log_alpha[-1, len(y_insert_blank) - 1],  # Sb
        log_alpha[-1, len(y_insert_blank) - 2]  # Snb
    ])
    prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
    state_seq[-1] = prev_state[paddle.argmax(candidates)]
    for t in range(ctc_probs.size(0) - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    output_alignment = []
    for t in range(0, ctc_probs.size(0)):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment
@ -1,67 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from typing import Any
from typing import Dict
from typing import List
from typing import Text

from deepspeech.utils.log import Log
from deepspeech.utils.tensor_utils import has_tensor

logger = Log(__name__).getlog()

__all__ = ["dynamic_import", "instance_class"]


def dynamic_import(import_path, alias=dict()):
    """Dynamically import a module and class.

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'deepspeech.models.u2:U2Model'
    :param dict alias: shortcut for registered class
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError("import_path should be one of {} or "
                         'include ":", e.g. "deepspeech.models.u2:U2Model" : '
                         "{}".format(set(alias), import_path))
    if ":" not in import_path:
        import_path = alias[import_path]

    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
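
# Usage sketch (illustrative, reusing the docstring's example path):
# >>> model_class = dynamic_import('deepspeech.models.u2:U2Model')
# >>> model = instance_class(model_class, configs)  # see instance_class below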


def filter_valid_args(args: Dict[Text, Any], valid_keys: List[Text]):
    # keep only the keys in `valid_keys` whose values are not None
    new_args = {
        key: val
        for key, val in args.items() if (key in valid_keys and val is not None)
    }
    return new_args


def filter_out_tenosr(args: Dict[Text, Any]):
    return {key: val for key, val in args.items() if not has_tensor(val)}


def instance_class(module_class, args: Dict[Text, Any]):
    valid_keys = inspect.signature(module_class).parameters.keys()
    new_args = filter_valid_args(args, valid_keys)
    logger.info(
        f"Instance: {module_class.__name__} {filter_out_tenosr(new_args)}.")
    return module_class(**new_args)
|
|
||||||
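
# Usage sketch for the helpers above. `argparse:ArgumentParser` is used only
# to keep the example self-contained; real call sites pass model paths such
# as "deepspeech.models.u2:U2Model".
cls = dynamic_import("argparse:ArgumentParser")
parser = instance_class(cls, {"prog": "demo", "epilog": None, "extra": 1})
# `epilog` is dropped (value is None) and `extra` is dropped (not a
# constructor parameter), so this builds ArgumentParser(prog="demo").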
@@ -1,206 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions to calculate error rates at different levels.
e.g. wer for word-level, cer for char-level.
"""
import numpy as np

__all__ = ['word_errors', 'char_errors', 'wer', 'cer']


def _levenshtein_distance(ref, hyp):
    """Levenshtein distance is a string metric for measuring the difference
    between two sequences. Informally, the levenshtein distance is defined as
    the minimum number of single-character edits (substitutions, insertions or
    deletions) required to change one word into the other. We can naturally
    extend the edits to word level when calculating the levenshtein distance
    of two sentences.
    """
    m = len(ref)
    n = len(hyp)

    # special cases
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # use O(min(m, n)) space
    distance = np.zeros((2, n + 1), dtype=np.int32)

    # initialize distance matrix
    for j in range(n + 1):
        distance[0][j] = j

    # calculate levenshtein distance
    for i in range(1, m + 1):
        prev_row_idx = (i - 1) % 2
        cur_row_idx = i % 2
        distance[cur_row_idx][0] = i
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
            else:
                s_num = distance[prev_row_idx][j - 1] + 1
                i_num = distance[cur_row_idx][j - 1] + 1
                d_num = distance[prev_row_idx][j] + 1
                distance[cur_row_idx][j] = min(s_num, i_num, d_num)

    return distance[m % 2][n]
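
# Worked example of the two-row DP above: turning "kitten" into "sitting"
# takes 3 edits (substitute k->s, substitute e->i, insert g).
print(_levenshtein_distance("kitten", "sitting"))  # 3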


def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence at word level.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Levenshtein distance and word number of reference sentence.
    :rtype: list
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    ref_words = list(filter(None, reference.split(delimiter)))
    hyp_words = list(filter(None, hypothesis.split(delimiter)))

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)


def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Compute the levenshtein distance between reference sequence and
    hypothesis sequence at char level.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Levenshtein distance and length of reference sentence.
    :rtype: list
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    join_char = ' '
    if remove_space:
        join_char = ''

    reference = join_char.join(list(filter(None, reference.split(' '))))
    hypothesis = join_char.join(list(filter(None, hypothesis.split(' '))))

    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)


def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER). WER compares reference text and
    hypothesis text at word level. WER is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    We can use levenshtein distance to calculate WER. Note that empty items
    will be removed when splitting sentences by delimiter.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If word number of reference is zero.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case,
                                         delimiter)

    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    wer = float(edit_distance) / ref_len
    return wer


def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate character error rate (CER). CER compares reference text and
    hypothesis text at char level. CER is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    We can use levenshtein distance to calculate CER. Chinese input should be
    encoded to unicode. Note that leading and trailing space characters will
    be truncated and multiple consecutive space characters in a sentence will
    be replaced by one space character.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters.
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference length is zero.
    """
    edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case,
                                         remove_space)

    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    cer = float(edit_distance) / ref_len
    return cer
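
# Quick usage sketch for the metrics above:
print(wer("the quick brown fox", "the quick brown fax"))  # 0.25 (1 sub / 4 words)
print(cer("abcd", "abcf"))                                # 0.25 (1 sub / 4 chars)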
@@ -1,88 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import nn

__all__ = [
    "summary", "gradient_norm", "freeze", "unfreeze", "print_grads",
    "print_params"
]


def summary(layer: nn.Layer, print_func=print):
    if print_func is None:
        return
    num_params = num_elements = 0
    for name, param in layer.state_dict().items():
        if print_func:
            print_func(
                "{} | {} | {}".format(name, param.shape, np.prod(param.shape)))
        num_elements += np.prod(param.shape)
        num_params += 1
    if print_func:
        num_elements = num_elements / 1024**2
        print_func(
            f"Total parameters: {num_params}, {num_elements:.2f}M elements.")


def print_grads(model, print_func=print):
    if print_func is None:
        return
    for n, p in model.named_parameters():
        msg = f"param grad: {n}: shape: {p.shape} grad: {p.grad}"
        print_func(msg)


def print_params(model, print_func=print):
    if print_func is None:
        return
    total = 0.0
    num_params = 0.0
    for n, p in model.named_parameters():
        msg = f"{n} | {p.shape} | {np.prod(p.shape)} | {not p.stop_gradient}"
        total += np.prod(p.shape)
        num_params += 1
        if print_func:
            print_func(msg)
    if print_func:
        total = total / 1024**2
        print_func(f"Total parameters: {num_params}, {total:.2f}M elements.")


def gradient_norm(layer: nn.Layer):
    grad_norm_dict = {}
    for name, param in layer.state_dict().items():
        if param.trainable:
            grad = param.gradient()  # returns a numpy.ndarray
            grad_norm_dict[name] = np.linalg.norm(grad) / grad.size
    return grad_norm_dict


def recursively_remove_weight_norm(layer: nn.Layer):
    for layer in layer.sublayers():
        try:
            nn.utils.remove_weight_norm(layer)
        except ValueError as e:
            # there is no weight norm hook in this layer
            pass


def freeze(layer: nn.Layer):
    for param in layer.parameters():
        param.trainable = False


def unfreeze(layer: nn.Layer):
    for param in layer.parameters():
        param.trainable = True
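
# Usage sketch for the helpers above (assumes paddle is installed):
from paddle import nn

net = nn.Linear(4, 2)
summary(net)  # prints "weight | [4, 2] | 8", "bias | [2] | 2", then a total
freeze(net)   # afterwards every parameter has trainable == False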
@@ -1,182 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import getpass
import logging
import os
import socket
import sys

from paddle import inference

FORMAT_STR = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
DATE_FMT_STR = '%Y/%m/%d %H:%M:%S'

logging.basicConfig(
    level=logging.DEBUG, format=FORMAT_STR, datefmt=DATE_FMT_STR)


def find_log_dir(log_dir=None):
    """Returns the most suitable directory to put log files into.
    Args:
        log_dir: str|None, if specified, the logfile(s) will be created in that
            directory. Otherwise if the --log_dir command-line flag is provided,
            the logfile will be created in that directory. Otherwise the logfile
            will be created in a standard location.
    Raises:
        FileNotFoundError: raised when it cannot find a log directory.
    """
    # Get a list of possible log dirs (will try to use them in order).
    if log_dir:
        # log_dir was explicitly specified as an arg, so use it and it alone.
        dirs = [log_dir]
    else:
        dirs = ['/tmp/', './']

    # Find the first usable log dir.
    for d in dirs:
        if os.path.isdir(d) and os.access(d, os.W_OK):
            return d
    raise FileNotFoundError(
        "Can't find a writable directory for logs, tried %s" % dirs)


def find_log_dir_and_names(program_name=None, log_dir=None):
    """Computes the directory and filename prefix for log file.
    Args:
        program_name: str|None, the filename part of the path to the program
            that is running, without its extension. e.g: if your program is
            called 'usr/bin/foobar.py' this method should probably be called
            with program_name='foobar'. However, this is just a convention;
            you can pass in any string you want, and it will be used as part
            of the log filename. In python standard logging mode, py_ is
            prepended to the default value when the program_name argument is
            omitted.
        log_dir: str|None, the desired log directory.
    Returns:
        (log_dir, file_prefix, symlink_prefix)
    Raises:
        FileNotFoundError: raised in Python 3 when it cannot find a log directory.
        OSError: raised in Python 2 when it cannot find a log directory.
    """
    if not program_name:
        # Strip the extension (foobar.par becomes foobar, and
        # fubar.py becomes fubar). We do this so that the log
        # file names are similar to C++ log file names.
        program_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]

        # Prepend py_ to files so that python code gets a unique file, and
        # so that C++ libraries do not try to write to the same log files as us.
        program_name = 'py_%s' % program_name

    actual_log_dir = find_log_dir(log_dir=log_dir)

    try:
        username = getpass.getuser()
    except KeyError:
        # This can happen, e.g. when running under docker w/o passwd file.
        if hasattr(os, 'getuid'):
            # Windows doesn't have os.getuid
            username = str(os.getuid())
        else:
            username = 'unknown'
    hostname = socket.gethostname()
    file_prefix = '%s.%s.%s.log' % (program_name, hostname, username)

    return actual_log_dir, file_prefix, program_name


class Log():

    log_name = None

    def __init__(self, logger=None):
        self.logger = logging.getLogger(logger)
        self.logger.setLevel(logging.DEBUG)

        file_dir = os.getcwd() + '/log'
        if not os.path.exists(file_dir):
            os.mkdir(file_dir)
        self.log_dir = file_dir

        actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
            program_name=None, log_dir=self.log_dir)

        basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
        filename = os.path.join(actual_log_dir, basename)
        if Log.log_name is None:
            Log.log_name = filename

        # Create a symlink to the log file with a canonical name.
        symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
        try:
            if os.path.islink(symlink):
                os.unlink(symlink)
            os.symlink(os.path.basename(Log.log_name), symlink)
        except EnvironmentError:
            # If it fails, we're sad but it's no error. Commonly, this
            # fails because the symlink was created by another user so
            # we can't modify it.
            pass

        if not self.logger.hasHandlers():
            formatter = logging.Formatter(fmt=FORMAT_STR, datefmt=DATE_FMT_STR)
            fh = logging.FileHandler(Log.log_name)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)
            ch.setFormatter(formatter)
            self.logger.addHandler(ch)

        # Stop propagation, otherwise ancestor loggers may print the
        # same record multiple times.
        self.logger.propagate = False

    def getlog(self):
        return self.logger


class Autolog:
    def __init__(self,
                 batch_size,
                 model_name="DeepSpeech",
                 model_precision="fp32"):
        import auto_log
        pid = os.getpid()
        if (os.environ['CUDA_VISIBLE_DEVICES'].strip() != ''):
            gpu_id = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[0])
            infer_config = inference.Config()
            infer_config.enable_use_gpu(100, gpu_id)
        else:
            gpu_id = None
            infer_config = inference.Config()
        autolog = auto_log.AutoLogger(
            model_name=model_name,
            model_precision=model_precision,
            batch_size=batch_size,
            data_shape="dynamic",
            save_path="./output/auto_log.lpg",
            inference_config=infer_config,
            pids=pid,
            process_name=None,
            gpu_ids=gpu_id,
            time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
            warmup=0)
        self.autolog = autolog

    def getlog(self):
        return self.autolog
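
# Usage sketch for the Log wrapper above: INFO and above go to the console,
# DEBUG and above to ./log/py_<prog>.<host>.<user>.log.DEBUG.<pid>.
logger = Log(__name__).getlog()
logger.info("model loaded")
logger.debug("file-only detail")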
@@ -1,30 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import wraps

from paddle import distributed as dist

__all__ = ["rank_zero_only"]


def rank_zero_only(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        rank = dist.get_rank()
        if rank != 0:
            return
        result = func(*args, **kwargs)
        return result

    return wrapper
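
# Usage sketch: in a multi-process paddle.distributed run, only the rank-0
# worker executes the decorated function; other ranks return None.
@rank_zero_only
def report(msg):
    print(msg)

report("printed once, by rank 0")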
@@ -1,112 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import socket
import socketserver
import struct
import time
import wave
from time import gmtime
from time import strftime

from deepspeech.frontend.utility import read_manifest

__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]


def socket_send(server_ip: str, server_port: int, data: bytes):
    # Connect to server and send data
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((server_ip, server_port))
    sent = data
    sock.sendall(struct.pack('>i', len(sent)) + sent)
    print('Speech[length=%d] Sent.' % len(sent))
    # Receive data from the server and shut down
    received = sock.recv(1024)
    print("Recognition Results: {}".format(received.decode('utf8')))
    sock.close()


def warm_up_test(audio_process_handler,
                 manifest_path,
                 num_test_cases,
                 random_seed=0):
    """Warming-up test."""
    manifest = read_manifest(manifest_path)
    rng = random.Random(random_seed)
    samples = rng.sample(manifest, num_test_cases)
    for idx, sample in enumerate(samples):
        print("Warm-up Test Case %d: %s" % (idx, sample['feat']))
        start_time = time.time()
        transcript = audio_process_handler(sample['feat'])
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))


class AsrTCPServer(socketserver.TCPServer):
    """The ASR TCP Server."""

    def __init__(self,
                 server_address,
                 RequestHandlerClass,
                 speech_save_dir,
                 audio_process_handler,
                 bind_and_activate=True):
        self.speech_save_dir = speech_save_dir
        self.audio_process_handler = audio_process_handler
        socketserver.TCPServer.__init__(
            self, server_address, RequestHandlerClass, bind_and_activate=True)


class AsrRequestHandler(socketserver.BaseRequestHandler):
    """The ASR request handler."""

    def handle(self):
        # receive data through TCP socket
        chunk = self.request.recv(1024)
        target_len = struct.unpack('>i', chunk[:4])[0]
        data = chunk[4:]
        while len(data) < target_len:
            chunk = self.request.recv(1024)
            data += chunk
        # write to file
        filename = self._write_to_file(data)

        print("Received utterance[length=%d] from %s, saved to %s." %
              (len(data), self.client_address[0], filename))
        start_time = time.time()
        transcript = self.server.audio_process_handler(filename)
        finish_time = time.time()
        print("Response Time: %f, Transcript: %s" %
              (finish_time - start_time, transcript))
        self.request.sendall(transcript.encode('utf-8'))

    def _write_to_file(self, data):
        # prepare save dir and filename
        if not os.path.exists(self.server.speech_save_dir):
            os.mkdir(self.server.speech_save_dir)
        timestamp = strftime("%Y%m%d%H%M%S", gmtime())
        out_filename = os.path.join(
            self.server.speech_save_dir,
            timestamp + "_" + self.client_address[0] + ".wav")
        # write to wav file
        file = wave.open(out_filename, 'wb')
        file.setnchannels(1)
        file.setsampwidth(2)
        file.setframerate(16000)
        file.writeframes(data)
        file.close()
        return out_filename
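
# Client-side usage sketch for socket_send. It assumes an AsrTCPServer is
# already listening on the address below and that "demo.wav" is a local
# 16 kHz mono PCM file (both hypothetical):
with open("demo.wav", "rb") as f:
    socket_send("127.0.0.1", 8086, f.read())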
@@ -1,180 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for Transformer."""
from typing import List
from typing import Tuple

import paddle

from deepspeech.utils.log import Log

__all__ = ["pad_sequence", "add_sos_eos", "th_accuracy", "has_tensor"]

logger = Log(__name__).getlog()


def has_tensor(val):
    if isinstance(val, (list, tuple)):
        for item in val:
            if has_tensor(item):
                return True
    elif isinstance(val, dict):
        for k, v in val.items():
            if has_tensor(v):
                return True
    else:
        return paddle.is_tensor(val)


def pad_sequence(sequences: List[paddle.Tensor],
                 batch_first: bool=False,
                 padding_value: float=0.0) -> paddle.Tensor:
    r"""Pad a list of variable length Tensors with ``padding_value``

    ``pad_sequence`` stacks a list of Tensors along a new dimension,
    and pads them to equal length. For example, if the input is a list of
    sequences with size ``L x *``, the output is of size ``T x B x *`` if
    batch_first is False, and ``B x T x *`` otherwise.

    `B` is batch size. It is equal to the number of elements in ``sequences``.
    `T` is length of the longest sequence.
    `L` is length of the sequence.
    `*` is any number of trailing dimensions, including none.

    Example:
        >>> from paddle.nn.utils.rnn import pad_sequence
        >>> a = paddle.ones(25, 300)
        >>> b = paddle.ones(22, 300)
        >>> c = paddle.ones(15, 300)
        >>> pad_sequence([a, b, c]).size()
        paddle.Tensor([25, 3, 300])

    Note:
        This function returns a Tensor of size ``T x B x *`` or ``B x T x *``
        where `T` is the length of the longest sequence. This function assumes
        trailing dimensions and type of all the Tensors in sequences are same.

    Args:
        sequences (list[Tensor]): list of variable length sequences.
        batch_first (bool, optional): output will be in ``B x T x *`` if True, or in
            ``T x B x *`` otherwise
        padding_value (float, optional): value for padded elements. Default: 0.

    Returns:
        Tensor of size ``T x B x *`` if :attr:`batch_first` is ``False``.
        Tensor of size ``B x T x *`` otherwise
    """

    # assuming trailing dimensions and type of all the Tensors
    # in sequences are same and fetching those from sequences[0]
    max_size = sequences[0].size()
    # TODO(Hui Zhang): slice does not support `end == start`
    # trailing_dims = max_size[1:]
    trailing_dims = max_size[1:] if max_size.ndim >= 2 else ()
    max_len = max([s.size(0) for s in sequences])
    if batch_first:
        out_dims = (len(sequences), max_len) + trailing_dims
    else:
        out_dims = (max_len, len(sequences)) + trailing_dims

    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        # use index notation to prevent duplicate references to the tensor
        if batch_first:
            out_tensor[i, :length, ...] = tensor
        else:
            out_tensor[:length, i, ...] = tensor

    return out_tensor


def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]:
    """Add <sos> and <eos> labels.
    Args:
        ys_pad (paddle.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding
    Returns:
        ys_in (paddle.Tensor) : (B, Lmax + 1)
        ys_out (paddle.Tensor) : (B, Lmax + 1)
    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=paddle.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    # TODO(Hui Zhang): use the commented code below once supported
    #_sos = paddle.to_tensor(
    #    [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #_eos = paddle.to_tensor(
    #    [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place)
    #ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys]
    #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys]
    #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id)
    B = ys_pad.size(0)
    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
    _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos
    ys_in = paddle.cat([_sos, ys_pad], dim=1)
    mask_pad = (ys_in == ignore_id)
    ys_in = ys_in.masked_fill(mask_pad, eos)

    ys_out = paddle.cat([ys_pad, _eos], dim=1)
    ys_out = ys_out.masked_fill(mask_pad, eos)
    mask_eos = (ys_out == ignore_id)
    ys_out = ys_out.masked_fill(mask_eos, eos)
    ys_out = ys_out.masked_fill(mask_pad, ignore_id)
    return ys_in, ys_out


def th_accuracy(pad_outputs: paddle.Tensor,
                pad_targets: paddle.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.
    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax, D).
        ignore_label (int): Ignore label id.
    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    # TODO(Hui Zhang): sum does not support bool type
    # numerator = paddle.sum(
    #     pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = (
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    numerator = paddle.sum(numerator.type_as(pad_targets))
    # TODO(Hui Zhang): sum does not support bool type
    # denominator = paddle.sum(mask)
    denominator = paddle.sum(mask.type_as(pad_targets))
    return float(numerator) / float(denominator)
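
# Numeric sketch of th_accuracy (hypothetical shapes; relies on the
# Tensor.size/.view helpers this repo patches onto paddle):
# B=1, Lmax=3, D=4; the last target position is padding (-1).
import numpy as np
pad_outputs = paddle.to_tensor(np.eye(4, dtype="float32")[[0, 1, 2]])  # (3, 4)
pad_targets = paddle.to_tensor([[0, 1, -1]])                           # (1, 3)
print(th_accuracy(pad_outputs, pad_targets, ignore_label=-1))  # 1.0 (2/2)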
@@ -1,127 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text

import textgrid


def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
    """Segment ctc alignment ids by continuous blank and repeated label.

    Args:
        alignment (List[int]): ctc alignment id sequence.
            e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
        blank_id (int, optional): blank id. Defaults to 0.

    Returns:
        List[List[int]]: token align, segmented alignment id sequence.
            e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
    """
    # convert the alignment to a Praat-friendly segmentation (Praat is a
    # tool for "doing phonetics by computer" that helps analyze alignments)
    align_segs = []
    # get frame-level duration for each token
    start = 0
    end = 0
    while end < len(alignment):
        while end < len(alignment) and alignment[end] == blank_id:  # blank
            end += 1
        if end == len(alignment):
            align_segs[-1].extend(alignment[start:])
            break
        end += 1
        while end < len(alignment) and alignment[end - 1] == alignment[
                end]:  # repeated label
            end += 1
        align_segs.append(alignment[start:end])
        start = end
    return align_segs


def align_to_tierformat(align_segs: List[List[int]],
                        subsample: int,
                        token_dict: Dict[int, Text],
                        blank_id=0) -> List[Text]:
    """Generate textgrid.Interval format from alignment segmentations.

    Args:
        align_segs (List[List[int]]): segmented ctc alignment ids.
        subsample (int): encoder subsample rate, i.e. frames per alignment
            step (features use 25ms frame_length, 10ms hop_length).
        token_dict (Dict[int, Text]): int -> str map.

    Returns:
        List[Text]: list of textgrid.Interval text, str(start, end, text).
    """
    hop_length = 10  # ms
    second_ms = 1000  # ms
    frame_per_second = second_ms / hop_length  # 25ms frame_length, 10ms hop_length
    second_per_frame = 1.0 / frame_per_second

    begin = 0
    duration = 0
    tierformat = []

    for idx, tokens in enumerate(align_segs):
        token_len = len(tokens)
        token = tokens[-1]
        # time duration in seconds
        duration = token_len * subsample * second_per_frame
        if idx < len(align_segs) - 1:
            print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
            tierformat.append(
                f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
        else:
            for i in tokens:
                if i != blank_id:
                    token = i
                    break
            print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
            tierformat.append(
                f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
        begin = begin + duration

    return tierformat


def generate_textgrid(maxtime: float,
                      intervals: List[Text],
                      output: Text,
                      name: Text='ali') -> None:
    """Create alignment textgrid file.

    Args:
        maxtime (float): audio duration.
        intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
        output (Text): textgrid filepath.
        name (Text, optional): tier or layer name. Defaults to 'ali'.
    """
    # Download Praat: https://www.fon.hum.uva.nl/praat/
    avg_interval = maxtime / (len(intervals) + 1)
    print(f"average second/token: {avg_interval}")
    margin = 0.0001

    tg = textgrid.TextGrid(maxTime=maxtime)
    tier = textgrid.IntervalTier(name=name, maxTime=maxtime)

    for dur in intervals:
        s, e, text = dur.split()
        tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)

    tg.append(tier)

    tg.write(output)
    print("Successfully generated textgrid {}.".format(output))
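
# Quick check of segment_alignment on the docstring example:
print(segment_alignment([0, 0, 0, 1, 1, 1, 2, 0, 0, 3]))
# -> [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]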
@@ -1,110 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common utility functions."""
import distutils.util
import math
import os
import random
from typing import List

import numpy as np
import paddle

__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]


def seed_all(seed: int=210329):
    np.random.seed(seed)
    random.seed(seed)
    paddle.seed(seed)


def print_arguments(args, info=None):
    """Print argparse's arguments.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="John", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    filename = ""
    if info:
        filename = info["__file__"]
        filename = os.path.basename(filename)
    print(f"----------- {filename} Configuration Arguments -----------")
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("-----------------------------------------------------------")


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Add argparse's argument.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_arguments("name", str, "John", "User name.", parser)
        args = parser.parse_args()
    """
    type = distutils.util.strtobool if type == bool else type
    argparser.add_argument(
        "--" + argname,
        default=default,
        type=type,
        help=help + ' Default: %(default)s.',
        **kwargs)


def log_add(args: List[int]) -> float:
    """Stable log add.

    Args:
        args (List[int]): log scores

    Returns:
        float: sum of log scores
    """
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp


def get_subsample(config):
    """Subsample rate from config.

    Args:
        config (yacs.config.CfgNode): yaml config

    Returns:
        int: subsample rate.
    """
    input_layer = config["model"]["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8
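
# Numeric check of log_add: log(exp(log 0.25) + exp(log 0.75)) = log(1.0) = 0.
print(log_add([math.log(0.25), math.log(0.75)]))  # ~0.0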
@@ -1,42 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python3 data/aishell/aishell.py \
--manifest_prefix='data/aishell/manifest' \
--target_dir='../dataset/aishell'

if [ $? -ne 0 ]; then
    echo "Prepare Aishell failed. Terminated."
    exit 1
fi


# build vocabulary
python3 tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/aishell/vocab.txt' \
--manifest_paths 'data/aishell/manifest.train' 'data/aishell/manifest.dev'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python3 tools/compute_mean_std.py \
--manifest_path='data/aishell/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/aishell/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "Aishell data preparation done."
exit 0
@@ -1,55 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/aishell > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python3 -u infer2x.py \
--num_samples=10 \
--beam_size=300 \
--feat_dim=161 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=False \
--share_rnn_weights=False \
--infer_manifest='data/aishell/manifest.test' \
--mean_std_path='models/aishell/mean_std.npz' \
--vocab_path='models/aishell/vocab.txt' \
--model_path='models/aishell/aishell_v1.8.pdparams' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@@ -1,54 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_ch.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null

# download well-trained model
cd models/aishell > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=1 \
python3 -u test2x.py \
--batch_size=64 \
--beam_size=300 \
--feat_dim=161 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--test_manifest='data/aishell/manifest.test' \
--mean_std_path='models/aishell/mean_std.npz' \
--vocab_path='models/aishell/vocab.txt' \
--model_path='models/aishell/aishell_v1.8.pdparams' \
--lang_model_path='models/lm/zh_giga.no_cna_cmn.prune01244.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='cer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@@ -1,45 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python3 data/librispeech/librispeech.py \
--manifest_prefix='data/librispeech/manifest' \
--target_dir='../dataset/librispeech' \
--full_download='True'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

cat data/librispeech/manifest.train-* | shuf > data/librispeech/manifest.train


# build vocabulary
python3 tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/librispeech/vocab.txt' \
--manifest_paths='data/librispeech/manifest.train'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python3 tools/compute_mean_std.py \
--manifest_path='data/librispeech/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/librispeech/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "LibriSpeech Data preparation done."
exit 0
@@ -1,55 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/baidu_en8k > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python3 -u infer2x.py \
--num_samples=10 \
--beam_size=500 \
--feat_dim=161 \
--num_proc_bsearch=5 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.4 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=False \
--share_rnn_weights=False \
--infer_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/baidu_en8k/mean_std.npz' \
--vocab_path='models/baidu_en8k/vocab.txt' \
--model_path='models/baidu_en8k/baidu_en8k_v1.8.pdparams' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@@ -1,55 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/baidu_en8k > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0 \
python3 -u test2x.py \
--batch_size=32 \
--beam_size=500 \
--feat_dim=161 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=1024 \
--alpha=1.4 \
--beta=0.35 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=False \
--share_rnn_weights=False \
--test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/baidu_en8k/mean_std.npz' \
--vocab_path='models/baidu_en8k/vocab.txt' \
--model_path='models/baidu_en8k/baidu_en8k_v1.8.pdparams' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi

exit 0
@@ -1,45 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python3 data/librispeech/librispeech.py \
--manifest_prefix='data/librispeech/manifest' \
--target_dir='../dataset/librispeech' \
--full_download='True'

if [ $? -ne 0 ]; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi

cat data/librispeech/manifest.train-* | shuf > data/librispeech/manifest.train


# build vocabulary
python3 tools/build_vocab.py \
--count_threshold=0 \
--vocab_path='data/librispeech/vocab.txt' \
--manifest_paths='data/librispeech/manifest.train'

if [ $? -ne 0 ]; then
    echo "Build vocabulary failed. Terminated."
    exit 1
fi


# compute mean and stddev for normalizer
python3 tools/compute_mean_std.py \
--manifest_path='data/librispeech/manifest.train' \
--num_samples=2000 \
--specgram_type='linear' \
--output_path='data/librispeech/mean_std.npz'

if [ $? -ne 0 ]; then
    echo "Compute mean and stddev failed. Terminated."
    exit 1
fi


echo "LibriSpeech Data preparation done."
exit 0
@@ -1,55 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# infer
CUDA_VISIBLE_DEVICES=0 \
python3 -u infer2x.py \
--num_samples=10 \
--beam_size=500 \
--feat_dim=161 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/librispeech_v1.8.pdparams' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in inference!"
    exit 1
fi


exit 0
@@ -1,55 +0,0 @@
#! /usr/bin/env bash

cd ../.. > /dev/null

# download language model
cd models/lm > /dev/null
bash download_lm_en.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# download well-trained model
cd models/librispeech > /dev/null
bash download_model.sh
if [ $? -ne 0 ]; then
    exit 1
fi
cd - > /dev/null


# evaluate model
CUDA_VISIBLE_DEVICES=0 \
python3 -u test2x.py \
--batch_size=32 \
--beam_size=500 \
--feat_dim=161 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--test_manifest='data/librispeech/manifest.test-clean' \
--mean_std_path='models/librispeech/mean_std.npz' \
--vocab_path='models/librispeech/vocab.txt' \
--model_path='models/librispeech/librispeech_v1.8.pdparams' \
--lang_model_path='models/lm/common_crawl_00.prune01111.trie.klm' \
--decoding_method='ctc_beam_search' \
--error_rate_type='wer' \
--specgram_type='linear'

if [ $? -ne 0 ]; then
    echo "Failed in evaluation!"
    exit 1
fi


exit 0
@ -1,163 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inferer for DeepSpeech2 model."""
import argparse
import functools

import numpy as np
import paddle
import paddle.fluid as fluid
from data_utils.data import DataGenerator
from model_utils.model_check import check_cuda
from model_utils.model_check import check_version

from deepspeech.models.ds2 import DeepSpeech2Model as DS2
from utils.error_rate import cer
from utils.error_rate import wer
from utils.utility import add_arguments
from utils.utility import print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('num_samples', int, 10, "# of samples to infer.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('feat_dim', int, 161, "Feature dim.")
add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights', bool, True,
        "Share input-hidden weights across bi-directional RNNs. Not for GRU.")
add_arg('infer_manifest', str,
        'data/librispeech/manifest.dev-clean',
        "Filepath of manifest to infer.")
add_arg('mean_std_path', str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('lang_model_path', str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('model_path', str,
        './checkpoints/libri/step_final',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('decoding_method', str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type', str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type', str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()


def infer():
    """Inference for DeepSpeech2."""

    # check if set use_gpu=True in paddlepaddle cpu version
    check_cuda(args.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        keep_transcription_text=True,
        place=place,
        is_training=False)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.infer_manifest,
        batch_size=args.num_samples,
        sortagrad=False,
        shuffle_method=None)
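
    # Note: num_samples is reused as the reader's batch size; the loop below
    # still iterates over every batch in the manifest.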

    # decoders only accept string encoded in utf-8
    vocab_list = [chars for chars in data_generator.vocab_list]
    for i, char in enumerate(vocab_list):
        if char == '':
            vocab_list[i] = " "
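    # An empty vocab entry is presumably the space token lost when the vocab
    # file was read line-by-line; restore it as a literal " ".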

    model = DS2(
        feat_size=args.feat_dim,
        dict_size=len(vocab_list),
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        share_rnn_weights=args.share_rnn_weights,
        blank_id=len(vocab_list) - 1)
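    # Convention here: the CTC blank label is the last entry of the
    # vocabulary, hence blank_id = len(vocab_list) - 1.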
    params_path = args.model_path
    model_dict = paddle.load(params_path)
    model.set_state_dict(model_dict)
    model.eval()
    error_rate_func = cer if args.error_rate_type == 'cer' else wer
    print("start inference ...")
    for infer_data in batch_reader():
        audio, target_transcripts, audio_len, mask = infer_data
        audio = np.transpose(audio, (0, 2, 1))
        audio_len = audio_len.reshape(-1)
        audio = paddle.to_tensor(audio)
        audio_len = paddle.to_tensor(audio_len)

        result_transcripts = model.decode(
            audio=audio,
            audio_len=audio_len,
            lang_model_path=args.lang_model_path,
            decoding_method=args.decoding_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            cutoff_top_n=args.cutoff_top_n,
            vocab_list=vocab_list,
            num_processes=args.num_proc_bsearch)
        for target, result in zip(target_transcripts, result_transcripts):
            print("\nTarget Transcription: %s\nOutput Transcription: %s" %
                  (target, result))
            print("Current error rate [%s] = %f" %
                  (args.error_rate_type, error_rate_func(target, result)))

    print("finish inference")


def main():
    print_arguments(args)
    infer()


if __name__ == '__main__':
    main()
@ -1,13 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,49 +0,0 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import paddle.fluid as fluid


def check_cuda(
        use_cuda,
        err="\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.
    """
    try:
        if use_cuda is True and fluid.is_compiled_with_cuda() is False:
            print(err)
            sys.exit(1)
    except Exception:
        # Never let the capability check itself crash the caller.
        pass


def check_version():
    """
    Log error and exit when the installed version of paddlepaddle is
    not satisfied.
    """
    err = "PaddlePaddle version 1.6 or higher is required, " \
          "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code."

    try:
        fluid.require_version('1.6.0')
    except Exception:
        print(err)
        sys.exit(1)
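
# Typical usage (mirroring infer2x.py and test2x.py): run both checks before
# constructing the model.
#
#     check_cuda(args.use_gpu)
#     check_version()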
@ -1,169 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
import argparse
import functools

import numpy as np
import paddle
import paddle.fluid as fluid
from data_utils.data import DataGenerator
from model_utils.model_check import check_cuda
from model_utils.model_check import check_version

from deepspeech.models.ds2 import DeepSpeech2Model as DS2
from utils.error_rate import char_errors
from utils.error_rate import word_errors
from utils.utility import add_arguments
from utils.utility import print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 128, "Minibatch size.")
add_arg('beam_size', int, 500, "Beam search width.")
add_arg('feat_dim', int, 161, "Feature dim.")
add_arg('num_proc_bsearch', int, 8, "# of CPUs for beam search.")
add_arg('num_conv_layers', int, 2, "# of convolution layers.")
add_arg('num_rnn_layers', int, 3, "# of recurrent layers.")
add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.")
add_arg('alpha', float, 2.5, "Coef of LM for beam search.")
add_arg('beta', float, 0.3, "Coef of WC for beam search.")
add_arg('cutoff_prob', float, 1.0, "Cutoff probability for pruning.")
add_arg('cutoff_top_n', int, 40, "Cutoff number for pruning.")
add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.")
add_arg('use_gpu', bool, True, "Use GPU or not.")
add_arg('share_rnn_weights', bool, True, "Share input-hidden weights across "
        "bi-directional RNNs. Not for GRU.")
add_arg('test_manifest', str,
        'data/librispeech/manifest.test-clean',
        "Filepath of manifest to evaluate.")
add_arg('mean_std_path', str,
        'data/librispeech/mean_std.npz',
        "Filepath of normalizer's mean & std.")
add_arg('vocab_path', str,
        'data/librispeech/vocab.txt',
        "Filepath of vocabulary.")
add_arg('model_path', str,
        './checkpoints/libri/step_final',
        "If None, the training starts from scratch, "
        "otherwise, it resumes from the pre-trained model.")
add_arg('lang_model_path', str,
        'models/lm/common_crawl_00.prune01111.trie.klm',
        "Filepath for language model.")
add_arg('decoding_method', str,
        'ctc_beam_search',
        "Decoding method. Options: ctc_beam_search, ctc_greedy",
        choices=['ctc_beam_search', 'ctc_greedy'])
add_arg('error_rate_type', str,
        'wer',
        "Error rate type for evaluation.",
        choices=['wer', 'cer'])
add_arg('specgram_type', str,
        'linear',
        "Audio feature type. Options: linear, mfcc.",
        choices=['linear', 'mfcc'])
# yapf: enable
args = parser.parse_args()


def evaluate():
    """Evaluate on whole test data for DeepSpeech2."""

    # check if set use_gpu=True in paddlepaddle cpu version
    check_cuda(args.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    data_generator = DataGenerator(
        vocab_filepath=args.vocab_path,
        mean_std_filepath=args.mean_std_path,
        augmentation_config='{}',
        specgram_type=args.specgram_type,
        keep_transcription_text=True,
        place=place,
        is_training=False)
    batch_reader = data_generator.batch_reader_creator(
        manifest_path=args.test_manifest,
        batch_size=args.batch_size,
        sortagrad=False,
        shuffle_method=None)

    # decoders only accept string encoded in utf-8
    vocab_list = [chars for chars in data_generator.vocab_list]
    for i, char in enumerate(vocab_list):
        if char == '':
            vocab_list[i] = " "

    model = DS2(
        feat_size=args.feat_dim,
        dict_size=len(vocab_list),
        num_conv_layers=args.num_conv_layers,
        num_rnn_layers=args.num_rnn_layers,
        rnn_size=args.rnn_layer_size,
        use_gru=args.use_gru,
        share_rnn_weights=args.share_rnn_weights,
        blank_id=len(vocab_list) - 1)

    params_path = args.model_path
    model_dict = paddle.load(params_path)
    model.set_state_dict(model_dict)
    model.eval()
    errors_func = char_errors if args.error_rate_type == 'cer' else word_errors
    errors_sum, len_refs, num_ins = 0.0, 0, 0
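    # Corpus-level metric: accumulate errors and reference lengths over all
    # utterances; the reported rate is errors_sum / len_refs.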

    print("start evaluation ...")
    for infer_data in batch_reader():
        audio, target_transcripts, audio_len, mask = infer_data
        audio = np.transpose(audio, (0, 2, 1))
        audio_len = audio_len.reshape(-1)
        audio = paddle.to_tensor(audio)
        audio_len = paddle.to_tensor(audio_len)
        result_transcripts = model.decode(
            audio=audio,
            audio_len=audio_len,
            lang_model_path=args.lang_model_path,
            decoding_method=args.decoding_method,
            beam_alpha=args.alpha,
            beam_beta=args.beta,
            beam_size=args.beam_size,
            cutoff_prob=args.cutoff_prob,
            cutoff_top_n=args.cutoff_top_n,
            vocab_list=vocab_list,
            num_processes=args.num_proc_bsearch)
        for target, result in zip(target_transcripts, result_transcripts):
            errors, len_ref = errors_func(target, result)
            errors_sum += errors
            len_refs += len_ref
            num_ins += 1
        print("Error rate [%s] (%d/?) = %f" %
              (args.error_rate_type, num_ins, errors_sum / len_refs))
    print("Final error rate [%s] (%d/%d) = %f" %
          (args.error_rate_type, num_ins, num_ins, errors_sum / len_refs))

    print("finish evaluation")


def main():
    print_arguments(args)
    evaluate()


if __name__ == '__main__':
    main()