fix logger and cmvn

pull/578/head
Hui Zhang 4 years ago
parent ee5a0c487f
commit ffb5756787

@ -22,6 +22,7 @@ from paddle.fluid import core
from paddle.nn import functional as F from paddle.nn import functional as F
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
#TODO(Hui Zhang): remove fluid import #TODO(Hui Zhang): remove fluid import
logger = Log(__name__).getlog() logger = Log(__name__).getlog()

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains feature normalizers.""" """Contains feature normalizers."""
import json
import random import random
import numpy as np import numpy as np
@ -22,12 +23,9 @@ from paddle.io import Dataset
from deepspeech.frontend.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
from deepspeech.frontend.utility import load_cmvn from deepspeech.frontend.utility import load_cmvn
from deepspeech.frontend.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["FeatureNormalizer"] __all__ = ["FeatureNormalizer"]
logger = Log(__name__).getlog()
class CollateFunc(object): class CollateFunc(object):
''' Collate function for AudioDataset ''' Collate function for AudioDataset
@ -171,7 +169,8 @@ class FeatureNormalizer(object):
collate_func = CollateFunc() collate_func = CollateFunc()
dataset = AudioDataset(manifest_path, featurize_func, num_samples) dataset = AudioDataset(manifest_path, featurize_func, num_samples,
self._rng)
batch_size = 20 batch_size = 20
data_loader = DataLoader( data_loader = DataLoader(
@ -198,8 +197,8 @@ class FeatureNormalizer(object):
wav_number += batch_size wav_number += batch_size
if wav_number % 1000 == 0: if wav_number % 1000 == 0:
logger.info('process {} wavs,{} frames'.format( print('process {} wavs,{} frames'.format(wav_number,
wav_number, int(all_number))) int(all_number)))
self.cmvn_info = { self.cmvn_info = {
'mean_stat': list(all_mean_stat.tolist()), 'mean_stat': list(all_mean_stat.tolist()),

@ -235,14 +235,6 @@ def _load_kaldi_cmvn(kaldi_cmvn_file):
return cmvn return cmvn
def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
npzfile = np.load(npz_cmvn_file)
means = npzfile["mean"] #(1, D)
istd = npzfile["istd"] #(1, D)
cmvn = np.array([means, istd])
return cmvn
def load_cmvn(cmvn_file: str, filetype: str): def load_cmvn(cmvn_file: str, filetype: str):
"""load cmvn from file. """load cmvn from file.
@ -262,8 +254,6 @@ def load_cmvn(cmvn_file: str, filetype: str):
cmvn = _load_json_cmvn(cmvn_file) cmvn = _load_json_cmvn(cmvn_file)
elif filetype == "kaldi": elif filetype == "kaldi":
cmvn = _load_kaldi_cmvn(cmvn_file) cmvn = _load_kaldi_cmvn(cmvn_file)
elif filetype == "npz":
cmvn = _load_npz_cmvn(cmvn_file)
else: else:
raise ValueError(f"cmvn file type no support: {filetype}") raise ValueError(f"cmvn file type no support: {filetype}")
return cmvn[0], cmvn[1] return cmvn[0], cmvn[1]

@ -16,7 +16,6 @@ import logging
import os import os
import socket import socket
import sys import sys
import time
def find_log_dir(log_dir=None): def find_log_dir(log_dir=None):
@ -106,16 +105,13 @@ class Log():
actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names( actual_log_dir, file_prefix, symlink_prefix = find_log_dir_and_names(
program_name=None, log_dir=self.log_dir) program_name=None, log_dir=self.log_dir)
basename = '%s.INFO.%s.%d' % ( basename = '%s.DEBUG.%d' % (file_prefix, os.getpid())
file_prefix,
time.strftime('%Y%m%d-%H%M', time.localtime(time.time())),
os.getpid())
filename = os.path.join(actual_log_dir, basename) filename = os.path.join(actual_log_dir, basename)
if Log.log_name is None: if Log.log_name is None:
Log.log_name = filename Log.log_name = filename
# Create a symlink to the log file with a canonical name. # Create a symlink to the log file with a canonical name.
symlink = os.path.join(actual_log_dir, symlink_prefix + '.INFO') symlink = os.path.join(actual_log_dir, symlink_prefix + '.DEBUG')
try: try:
if os.path.islink(symlink): if os.path.islink(symlink):
os.unlink(symlink) os.unlink(symlink)
@ -126,26 +122,26 @@ class Log():
# we can't modify it # we can't modify it
pass pass
fh = logging.FileHandler(Log.log_name) if not self.logger.hasHandlers():
fh.setLevel(logging.DEBUG) format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
formatter = logging.Formatter(
fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
fh = logging.FileHandler(Log.log_name)
fh.setFormatter(formatter)
fh.setLevel(logging.DEBUG)
self.logger.addHandler(fh)
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setLevel(logging.INFO) ch.setLevel(logging.INFO)
ch.setFormatter(formatter)
self.logger.addHandler(ch)
format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' #fh.close()
formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') #ch.close()
fh.setFormatter(formatter)
ch.setFormatter(formatter)
self.logger.addHandler(fh)
self.logger.addHandler(ch)
# stop propagate for propagating may print # stop propagate for propagating may print
# log multiple times # log multiple times
# self.logger.propagate = False self.logger.propagate = False
fh.close()
ch.close()
def getlog(self): def getlog(self):
return self.logger return self.logger

@ -42,8 +42,9 @@ python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--stride_ms=10.0 \ --stride_ms=10.0 \
--window_ms=25.0 \ --window_ms=25.0 \
--sample_rate=16000 \ --sample_rate=16000 \
--num_samples=2000 \
--num_workers=0 \ --num_workers=0 \
--output_path="data/mean_std.npz" --output_path="data/mean_std.json"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated." echo "Compute mean and stddev failed. Terminated."

@ -34,8 +34,8 @@ data:
# network architecture # network architecture
model: model:
cmvn_file: "data/mean_std.npz" cmvn_file: "data/mean_std.json"
cmvn_file_type: "npz" cmvn_file_type: "json"
# encoder related # encoder related
encoder: conformer encoder: conformer
encoder_conf: encoder_conf:

@ -43,7 +43,11 @@ python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--specgram_type="fbank" \ --specgram_type="fbank" \
--feat_dim=80 \ --feat_dim=80 \
--delta_delta=false \ --delta_delta=false \
--output_path="data/mean_std.npz" --sample_rate=16000 \
--stride_ms=10.0 \
--window_ms=25.0 \
--num_workers=0 \
--output_path="data/mean_std.json"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Compute mean and stddev failed. Terminated." echo "Compute mean and stddev failed. Terminated."

@ -74,8 +74,8 @@ data:
# network architecture # network architecture
model: model:
cmvn_file: "data/mean_std.npz" cmvn_file: "data/mean_std.json"
cmvn_file_type: "npz" cmvn_file_type: "json"
# encoder related # encoder related
encoder: conformer encoder: conformer
encoder_conf: encoder_conf:

@ -27,7 +27,7 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi") add_arg('feat_type', str, "raw", "speech feature type, e.g. raw(wav, flac), kaldi")
add_arg('cmvn_path', str, add_arg('cmvn_path', str,
'examples/librispeech/data/mean_std.npz', 'examples/librispeech/data/mean_std.json',
"Filepath of cmvn.") "Filepath of cmvn.")
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")
add_arg('vocab_path', str, add_arg('vocab_path', str,
@ -52,8 +52,8 @@ def main():
fout = open(args.output_path, 'w', encoding='utf-8') fout = open(args.output_path, 'w', encoding='utf-8')
# get feat dim # get feat dim
mean, std = load_cmvn(args.cmvn_path, filetype='npz') mean, std = load_cmvn(args.cmvn_path, filetype='json')
feat_dim = mean.shape[1] #(1, D) feat_dim = mean.shape[0] #(D)
print(f"Feature dim: {feat_dim}") print(f"Feature dim: {feat_dim}")
text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix)

Loading…
Cancel
Save