fix text featurizer

pull/578/head
Hui Zhang 5 years ago
parent 64f0bad5ca
commit af453e0234

@@ -295,6 +295,7 @@
"source": [
"dataset = ManifestDataset(\n",
" config.data.test_manifest,\n",
" config.data.unit_type,\n",
" config.data.vocab_filepath,\n",
" config.data.mean_std_filepath,\n",
" augmentation_config=\"{}\",\n",

@@ -85,8 +85,10 @@ def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,

@@ -37,8 +37,10 @@ def start_server(config, args):
"""Start the ASR server"""
dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,

@@ -43,8 +43,10 @@ def tune(config, args):
dev_dataset = ManifestDataset(
config.data.dev_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,

@@ -21,7 +21,9 @@ _C.data = CN(
train_manifest="",
dev_manifest="",
test_manifest="",
unit_type="char",
vocab_filepath="",
spm_model_prefix="",
mean_std_filepath="",
augmentation_config="",
max_duration=float('inf'),

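For context, a minimal sketch of how the new data fields would be set from a YAML config; the file name and values are hypothetical, and only the field names follow the `_C.data` defaults above:

from yacs.config import CfgNode as CN

config = CN(dict(data=CN(dict(
    unit_type="char",         # one of: char, word, spm
    vocab_filepath="",
    spm_model_prefix="",      # used only when unit_type == 'spm'
    mean_std_filepath="",
))))
# hypothetical YAML override:
#   data:
#     unit_type: spm
#     vocab_filepath: data/vocab.txt
#     spm_model_prefix: data/bpe_unigram_5000
config.merge_from_file("conf/deepspeech2.yaml")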
@@ -148,8 +148,10 @@ class DeepSpeech2Trainer(Trainer):
train_dataset = ManifestDataset(
config.data.train_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config=io.open(
config.data.augmentation_config, mode='r',
encoding='utf8').read(),
@@ -168,8 +170,10 @@ class DeepSpeech2Trainer(Trainer):
dev_dataset = ManifestDataset(
config.data.dev_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,
@@ -361,8 +365,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
# return raw text
test_dataset = ManifestDataset(
config.data.test_manifest,
config.data.unit_type,
config.data.vocab_filepath,
config.data.mean_std_filepath,
spm_model_prefix=config.data.spm_model_prefix,
augmentation_config="{}",
max_duration=config.data.max_duration,
min_duration=config.data.min_duration,

@@ -109,7 +109,7 @@ class AudioFeaturizer(object):
feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2 +
1)
elif self._specgram_type == 'mfcc':
# mfcc,delta, delta-delta
# mfcc, delta, delta-delta
feat_dim = int(13 * 3)
else:
raise ValueError("Unknown specgram_type %s. "

@@ -52,7 +52,9 @@ class SpeechFeaturizer(object):
"""
def __init__(self,
unit_type,
vocab_filepath,
spm_model_prefix=None,
specgram_type='linear',
stride_ms=10.0,
window_ms=20.0,
@@ -70,7 +72,8 @@ class SpeechFeaturizer(object):
target_sample_rate=target_sample_rate,
use_dB_normalization=use_dB_normalization,
target_dB=target_dB)
self._text_featurizer = TextFeaturizer(vocab_filepath)
self._text_featurizer = TextFeaturizer(unit_type, vocab_filepath,
spm_model_prefix)
def featurize(self, speech_segment, keep_transcription_text):
"""Extract features for speech segment.

@@ -15,25 +15,35 @@
import os
import codecs
import sentencepiece as spm
from deepspeech.frontend.utility import UNK
class TextFeaturizer(object):
"""Text featurizer, for processing or extracting features from text.
Currently, it only supports char-level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
class TextFeaturizer(object):
def __init__(self, unit_type, vocab_filepath, spm_model_prefix=None):
"""Text featurizer, for processing or extracting features from text.
:param vocab_filepath: Filepath to load vocabulary for token indices
conversion.
:type specgram_type: str
"""
Currently, it supports char/word/sentence-piece level tokenizing and conversion into
a list of token indices. Note that the token indexing order follows the
given vocabulary file.
def __init__(self, vocab_filepath):
self.unk = '<unk>'
Args:
unit_type (str): unit type, e.g. char, word, spm
vocab_filepath (str): Filepath to load vocabulary for token indices conversion.
spm_model_prefix (str, optional): spm model prefix. Defaults to None.
"""
assert unit_type in ('char', 'spm', 'word')
self.unk = UNK
self.unit_type = unit_type
self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
vocab_filepath)
        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)
    def featurize(self, text):
        """Convert text string to a list of token indices. Note
        that the token indexing order follows the given vocabulary file.
@@ -43,7 +53,13 @@ class TextFeaturizer(object):
        :return: List of token indices.
:rtype: list
"""
tokens = self._char_tokenize(text)
        if self.unit_type == 'char':
            tokens = self._char_tokenize(text)
        elif self.unit_type == 'word':
            tokens = self._word_tokenize(text)
        else:
            tokens = self._spm_tokenize(text)
ids = []
for token in tokens:
token = token if token in self._vocab_dict else self.unk
@@ -72,6 +88,42 @@ class TextFeaturizer(object):
"""Character tokenizer."""
return list(text.strip())
def _word_tokenize(self, text):
"""Word tokenizer, spearte by <space>."""
return text.strip().split()
def _spm_tokenize(self, text):
"""spm tokenize.
Args:
text (str): text string.
Returns:
            List[str]: sentence piece tokens.
"""
stats = {"num_empty": 0, "num_filtered": 0}
def valid(line):
return True
def encode(l):
return self.sp.EncodeAsPieces(l)
def encode_line(line):
line = line.strip()
if len(line) > 0:
line = encode(line)
if valid(line):
return line
else:
stats["num_filtered"] += 1
else:
stats["num_empty"] += 1
            return []  # empty or filtered input yields no tokens
enc_line = encode_line(text)
return enc_line
def _load_vocabulary_from_file(self, vocab_filepath):
"""Load vocabulary from file."""
vocab_lines = []

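A usage sketch of the reworked featurizer; the module path and file paths are assumptions. With unit_type='spm' it loads '<spm_model_prefix>.model' through sentencepiece:

from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer

# char-level: the transcript is split into characters
char_feat = TextFeaturizer(unit_type='char', vocab_filepath='data/vocab.txt')
char_ids = char_feat.featurize('hello')  # indices follow vocab.txt order

# spm: pieces come from a trained sentencepiece model
spm_feat = TextFeaturizer(unit_type='spm', vocab_filepath='data/vocab.txt',
                          spm_model_prefix='data/bpe_unigram_5000')
spm_ids = spm_feat.featurize('hello world')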
@@ -16,6 +16,7 @@
import numpy as np
import random
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.audio import AudioSegment
@@ -79,10 +80,8 @@ class FeatureNormalizer(object):
def _read_mean_std_from_file(self, filepath, eps=1e-20):
"""Load mean and std from file."""
npzfile = np.load(filepath)
self._mean = npzfile["mean"]
std = npzfile["std"]
std = np.clip(std, eps, None)
        # load_cmvn already returns the inverse std for 'npz' (see _load_npz_cmvn)
        mean, istd = load_cmvn(filepath, filetype='npz')
        self._mean = mean
        self._istd = istd
def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
@@ -92,8 +91,7 @@ class FeatureNormalizer(object):
features = []
for instance in sampled_manifest:
features.append(
featurize_func(
AudioSegment.from_file(instance["audio_filepath"])))
featurize_func(AudioSegment.from_file(instance["feat"])))
features = np.hstack(features) #(D, T)
self._mean = np.mean(features, axis=1).reshape([-1, 1]) #(D, 1)
self._std = np.std(features, axis=1).reshape([-1, 1]) #(D, 1)

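For reference, a sketch of the .npz layout that FeatureNormalizer and _load_npz_cmvn expect; the keys 'mean' and 'std' and the (D, 1) shapes follow the code above, while the dimension value is illustrative:

import numpy as np

D = 161  # illustrative feature dim, e.g. linear spectrogram bins
np.savez('data/mean_std.npz',
         mean=np.zeros([D, 1]),  # per-dim mean, shape (D, 1)
         std=np.ones([D, 1]))    # per-dim std, shape (D, 1)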
@@ -20,6 +20,7 @@ import os
import tarfile
import time
import logging
from typing import List
from threading import Thread
from multiprocessing import Process, Manager, Value
@@ -39,31 +40,32 @@ EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
# """Load and parse manifest file.
# Instances with durations outside [min_duration, max_duration] will be
# filtered out.
# :param manifest_path: Manifest file to load and parse.
# :type manifest_path: str
# :param max_duration:maximum output seq length, in seconds for raw wav, in frame numbers for feature data.
# :type max_duration: float
# :param min_duration: minimum input seq length, in seconds for raw wav, in frame numbers for feature data.
# :type min_duration: float
# :return: Manifest parsing results. List of dict.
# :rtype: list
# :raises IOError: If failed to parse the manifest.
# """
def read_manifest(
manifest_path,
max_input_len=float('inf'),
min_input_len=0.0,
max_output_len=500.0,
max_output_len=float('inf'),
min_output_len=0.0,
max_output_input_ratio=10.0,
min_output_input_ratio=0.05, ):
max_output_input_ratio=float('inf'),
min_output_input_ratio=0.0, ):
"""Load and parse manifest file.
Args:
manifest_path ([type]): Manifest file to load and parse.
max_input_len ([type], optional): maximum output seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to float('inf').
min_input_len (float, optional): minimum input seq length, in seconds for raw wav, in frame numbers for feature data. Defaults to 0.0.
max_output_len (float, optional): maximum input seq length, in modeling units. Defaults to 500.0.
min_output_len (float, optional): minimum input seq length, in modeling units. Defaults to 0.0.
max_output_input_ratio (float, optional): maximum output seq length/output seq length ratio. Defaults to 10.0.
min_output_input_ratio (float, optional): minimum output seq length/output seq length ratio. Defaults to 0.05.
Raises:
IOError: If failed to parse the manifest.
Returns:
List[dict]: Manifest parsing results.
"""
manifest = []
for json_line in codecs.open(manifest_path, 'r', 'utf-8'):
@@ -71,33 +73,23 @@ def read_manifest(
json_data = json.loads(json_line)
except Exception as e:
raise IOError("Error reading manifest: %s" % str(e))
feat_len = json_data["feat_shape"][0]
token_len = json_data["token_shape"][0]
feat_len = json_data["feat_shape"][
0] if 'feat_shape' in json_data else 1.0
token_len = json_data["token_shape"][
0] if 'token_shape' in json_data else 1.0
conditions = [
feat_len > min_input_len,
feat_len < max_input_len,
token_len > min_output_len,
token_len < max_output_len,
token_len / feat_len > min_output_input_ratio,
token_len / feat_len < max_output_input_ratio,
feat_len >= min_input_len,
feat_len <= max_input_len,
token_len >= min_output_len,
token_len <= max_output_len,
token_len / feat_len >= min_output_input_ratio,
token_len / feat_len <= max_output_input_ratio,
]
if all(conditions):
manifest.append(json_data)
return manifest
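To make the new inclusive bounds concrete, here is a sketch of a single manifest line; the field values are illustrative, the field names match those read above:

import json

entry = {
    "feat": "data/wav/utt1.flac",  # audio path, renamed from "audio_filepath"
    "feat_shape": [3.2],           # seconds for raw wav, frames for features
    "text": "hello world",
    "token_shape": [11, 4000],     # (token_len, vocab_size)
}
# kept iff min_input_len <= 3.2 <= max_input_len,
#          min_output_len <= 11 <= max_output_len,
#          min_output_input_ratio <= 11 / 3.2 <= max_output_input_ratio
print(json.dumps(entry))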
# parser.add_argument('--max_input_len', type=float,
# default=20,
# help='maximum output seq length, in seconds for raw wav, in frame numbers for feature data')
# parser.add_argument('--min_output_len', type=float,
# default=0, help='minimum input seq length, in modeling units')
# parser.add_argument('--max_output_len', type=float,
# default=500,
# help='maximum output seq length, in modeling units')
# parser.add_argument('--min_output_input_ratio', type=float, default=0.05,
# help='minimum output seq length/output seq length ratio')
# parser.add_argument('--max_output_input_ratio', type=float, default=10,
# help='maximum output seq length/output seq length ratio')
def rms_to_db(rms: float):
"""Root Mean Square to dB.
@@ -251,8 +243,8 @@ def _load_kaldi_cmvn(kaldi_cmvn_file):
def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
npzfile = np.load(npz_cmvn_file)
means = npzfile["mean"]
std = npzfile["std"]
means = npzfile["mean"] #(D, 1)
std = npzfile["std"] #(D, 1)
std = np.clip(std, eps, None)
    istd = 1.0 / std  # inverse std, stored alongside the mean
    cmvn = np.array([means, istd])
@@ -278,7 +270,7 @@ def load_cmvn(cmvn_file: str, filetype: str):
cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file)
elif filtype == "npz":
elif filetype == "npz":
cmvn = _load_npz_cmvn(cmvn_file)
else:
raise ValueError(f"cmvn file type no support: {filetype}")

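Usage sketch, assuming load_cmvn returns the stacked (mean, inverse std) pair built by _load_npz_cmvn above:

import numpy as np
from deepspeech.frontend.utility import load_cmvn

mean, istd = load_cmvn('data/mean_std.npz', filetype='npz')  # each (D, 1)
feat = np.random.randn(mean.shape[0], 100)  # illustrative (D, T) features
normalized = (feat - mean) * istd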
@@ -21,8 +21,10 @@ from deepspeech.io.dataset import ManifestDataset
def create_dataloader(manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
spm_model_prefix,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
@@ -42,8 +44,10 @@ def create_dataloader(manifest_path,
dataset = ManifestDataset(
manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
spm_model_prefix=spm_model_prefix,
augmentation_config=augmentation_config,
max_duration=max_duration,
min_duration=min_duration,

@@ -38,8 +38,10 @@ __all__ = [
class ManifestDataset(Dataset):
def __init__(self,
manifest_path,
unit_type,
vocab_filepath,
mean_std_filepath,
spm_model_prefix=None,
augmentation_config='{}',
max_duration=float('inf'),
min_duration=0.0,
@@ -57,8 +59,10 @@ class ManifestDataset(Dataset):
        Args:
            manifest_path (str): manifest json file path
            vocab_filepath (str): vocab file path
            unit_type (str): token unit type, e.g. char, word, spm
            vocab_filepath (str): vocab file path.
            mean_std_filepath (str): mean and std file path, with suffix *.npz
            spm_model_prefix (str): spm model prefix, needed if `unit_type` is spm.
            augmentation_config (str, optional): augmentation json str. Defaults to '{}'.
            max_duration (float, optional): audio length in seconds must be less than this. Defaults to float('inf').
            min_duration (float, optional): audio length in seconds must be greater than this. Defaults to 0.0.
@@ -78,10 +82,12 @@ class ManifestDataset(Dataset):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
self._augmentation_pipeline = AugmentationPipeline(
self._audio_augmentation_pipeline = AugmentationPipeline(
augmentation_config=augmentation_config, random_seed=random_seed)
self._speech_featurizer = SpeechFeaturizer(
unit_type=unit_type,
vocab_filepath=vocab_filepath,
spm_model_prefix=spm_model_prefix,
specgram_type=specgram_type,
stride_ms=stride_ms,
window_ms=window_ms,
@@ -174,7 +180,7 @@ class ManifestDataset(Dataset):
self._subfile_from_tar(audio_file), transcript)
else:
speech_segment = SpeechSegment.from_file(audio_file, transcript)
self._augmentation_pipeline.transform_audio(speech_segment)
self._audio_augmentation_pipeline.transform_audio(speech_segment)
specgram, transcript_part = self._speech_featurizer.featurize(
speech_segment, self._keep_transcription_text)
specgram = self._normalizer.apply(specgram)
@@ -191,7 +197,7 @@ class ManifestDataset(Dataset):
def reader():
for instance in manifest:
inst = self.process_utterance(instance["audio_filepath"],
inst = self.process_utterance(instance["feat"],
instance["text"])
yield inst
@@ -202,5 +208,4 @@ class ManifestDataset(Dataset):
def __getitem__(self, idx):
instance = self._manifest[idx]
return self.process_utterance(instance["audio_filepath"],
instance["text"])
return self.process_utterance(instance["feat"], instance["text"])

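A construction sketch matching the new ManifestDataset signature; the paths are hypothetical, and augmentation is disabled as in the eval scripts above:

from deepspeech.io.dataset import ManifestDataset

dataset = ManifestDataset(
    'data/manifest.test',
    unit_type='spm',
    vocab_filepath='data/vocab.txt',
    mean_std_filepath='data/mean_std.npz',
    spm_model_prefix='data/bpe_unigram_5000',
    augmentation_config='{}',   # "{}" means no augmentation
    max_duration=float('inf'),
    min_duration=0.0)
specgram, text_part = dataset[0]  # normalized features, transcript part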
@@ -52,6 +52,7 @@ fi
# format manifest with tokenids, vocab size
python3 ${MAIN_ROOT}/utils/format_data.py \
--feat_type "raw" \
--cmvn_path "data/mean_std.npz" \
--unit_type "bpe" \
--bpe_model_prefix ${bpeprefix} \
--vocab_path="data/vocab.txt" \

@@ -3,6 +3,7 @@ resampy==0.2.2
SoundFile==0.9.0.post1
python_speech_features
tensorboardX
sentencepiece
yacs
typeguard
pre-commit

@@ -24,6 +24,7 @@ from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.utility import UNK
from deepspeech.frontend.utility import BLANK
from deepspeech.frontend.utility import SOS
from deepspeech.frontend.utility import load_cmvn
from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments
@@ -31,10 +32,13 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.npz',
"Filepath of cmvn.")
add_arg('unit_type', str, "character", "Unit type, e.g. character, word, bpe")
add_arg('vocab_path', str,
'examples/librispeech/data/vocab.txt',
"Filepath to write the vocabulary.")
"Filepath of the vocabulary.")
add_arg('manifest_paths', str,
None,
"Filepaths of manifests for building vocabulary. "
@@ -51,6 +55,11 @@ args = parser.parse_args()
def main():
print_arguments(args)
# get feat dim
mean, std = load_cmvn(args.cmvn_path, filetype='npz')
feat_dim = mean.shape[0]
print(f"Feature dim: {feat_dim}")
# read vocab
vocab = dict()
with open(args.vocab_path, 'r', encoding='utf-8') as fin:
@@ -58,6 +67,7 @@ def main():
token = line.strip()
vocab[token] = len(vocab)
vocab_size = len(vocab)
print(f"Vocab size: {vocab_size}")
fout = open(args.output_path, 'w', encoding='utf-8')
@@ -78,6 +88,12 @@ def main():
line_json['token'] = tokens
line_json['token_id'] = tokenids
line_json['token_shape'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
else: # kaldi
            raise NotImplementedError('kaldi feat is not supported yet!')
fout.write(json.dumps(line_json) + '\n')
else:
import sentencepiece as spm
@@ -118,6 +134,12 @@ def main():
line_json['token'] = tokens
line_json['token_id'] = tokenids
line_json['token_shape'] = (len(tokenids), vocab_size)
feat_shape = line_json['feat_shape']
assert isinstance(feat_shape, (list, tuple)), type(feat_shape)
if args.feat_type == 'raw':
feat_shape.append(feat_dim)
else: # kaldi
            raise NotImplementedError('kaldi feat is not supported yet!')
fout.write(json.dumps(line_json) + '\n')
fout.close()

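The net effect of the format_data.py change on one manifest entry, with illustrative values: the feat_dim read from the cmvn file is appended to feat_shape, so downstream code can see both length and feature dimension.

# before: {"feat": "utt1.flac", "feat_shape": [3.2], "text": "hello"}
# after:  {"feat": "utt1.flac", "feat_shape": [3.2, 161], "text": "hello",
#          "token": ["h", "e", "l", "l", "o"],
#          "token_id": [20, 13, 24, 24, 27],
#          "token_shape": [5, 29]}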