add feature pipeline layer(cmvn, fbank), but to_static and jit.layer output is not equal

pull/2212/head
Hui Zhang 2 years ago
parent 67709155e9
commit 8690a00bd8

@ -74,16 +74,16 @@ def _feature_window_function(
window_size: int, window_size: int,
blackman_coeff: float, blackman_coeff: float,
dtype: int, ) -> Tensor: dtype: int, ) -> Tensor:
if window_type == HANNING: if window_type == "hann":
return get_window('hann', window_size, fftbins=False, dtype=dtype) return get_window('hann', window_size, fftbins=False, dtype=dtype)
elif window_type == HAMMING: elif window_type == "hamming":
return get_window('hamming', window_size, fftbins=False, dtype=dtype) return get_window('hamming', window_size, fftbins=False, dtype=dtype)
elif window_type == POVEY: elif window_type == "povey":
return get_window( return get_window(
'hann', window_size, fftbins=False, dtype=dtype).pow(0.85) 'hann', window_size, fftbins=False, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR: elif window_type == "rect":
return paddle.ones([window_size], dtype=dtype) return paddle.ones([window_size], dtype=dtype)
elif window_type == BLACKMAN: elif window_type == "blackman":
a = 2 * math.pi / (window_size - 1) a = 2 * math.pi / (window_size - 1)
window_function = paddle.arange(window_size, dtype=dtype) window_function = paddle.arange(window_size, dtype=dtype)
return (blackman_coeff - 0.5 * paddle.cos(a * window_function) + return (blackman_coeff - 0.5 * paddle.cos(a * window_function) +
@ -216,7 +216,7 @@ def spectrogram(waveform: Tensor,
sr: int=16000, sr: int=16000,
snip_edges: bool=True, snip_edges: bool=True,
subtract_mean: bool=False, subtract_mean: bool=False,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return a spectrogram from a waveform. The output is identical to Kaldi's. """Compute and return a spectrogram from a waveform. The output is identical to Kaldi's.
Args: Args:
@ -236,7 +236,7 @@ def spectrogram(waveform: Tensor,
snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it snip_edges (bool, optional): Drop samples in the end of waveform that cann't fit a singal frame when it
is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True. is set True. Otherwise performs reflect padding to the end of waveform. Defaults to True.
subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False. subtract_mean (bool, optional): Whether to subtract mean of feature files. Defaults to False.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
Returns: Returns:
Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames Tensor: A spectrogram tensor with shape `(m, padded_window_size // 2 + 1)` where m is the number of frames
@ -418,11 +418,11 @@ def fbank(waveform: Tensor,
vtln_high: float=-500.0, vtln_high: float=-500.0,
vtln_low: float=100.0, vtln_low: float=100.0,
vtln_warp: float=1.0, vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return filter banks from a waveform. The output is identical to Kaldi's. """Compute and return filter banks from a waveform. The output is identical to Kaldi's.
Args: Args:
waveform (Tensor): A waveform tensor with shape `(C, T)`. waveform (Tensor): A waveform tensor with shape `(C, T)`. `C` is in the range [0,1].
blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42. blackman_coeff (float, optional): Coefficient for Blackman window.. Defaults to 0.42.
channel (int, optional): Select the channel of waveform. Defaults to -1. channel (int, optional): Select the channel of waveform. Defaults to -1.
dither (float, optional): Dithering constant . Defaults to 0.0. dither (float, optional): Dithering constant . Defaults to 0.0.
@ -448,7 +448,7 @@ def fbank(waveform: Tensor,
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0. vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function. Defaults to -500.0.
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0. vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function. Defaults to 100.0.
vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0. vtln_warp (float, optional): Vtln warp factor. Defaults to 1.0.
window_type (str, optional): Choose type of window for FFT computation. Defaults to POVEY. window_type (str, optional): Choose type of window for FFT computation. Defaults to "povey".
Returns: Returns:
Tensor: A filter banks tensor with shape `(m, n_mels)`. Tensor: A filter banks tensor with shape `(m, n_mels)`.
@ -537,7 +537,7 @@ def mfcc(waveform: Tensor,
vtln_high: float=-500.0, vtln_high: float=-500.0,
vtln_low: float=100.0, vtln_low: float=100.0,
vtln_warp: float=1.0, vtln_warp: float=1.0,
window_type: str=POVEY) -> Tensor: window_type: str="povey") -> Tensor:
"""Compute and return mel frequency cepstral coefficients from a waveform. The output is """Compute and return mel frequency cepstral coefficients from a waveform. The output is
identical to Kaldi's. identical to Kaldi's.

@ -18,6 +18,7 @@ from pathlib import Path
import paddle import paddle
import soundfile import soundfile
import numpy as np
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation from paddlespeech.audio.transform.transformation import Transformation
@ -77,6 +78,8 @@ class U2Infer():
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode decode_config = self.config.decode

@ -474,13 +474,20 @@ class U2Tester(U2Trainer):
def export(self): def export(self):
infer_model, input_spec = self.load_inferspec() infer_model, input_spec = self.load_inferspec()
infer_model.eval() infer_model.eval()
paddle.set_device('cpu')
assert isinstance(input_spec, list), type(input_spec) assert isinstance(input_spec, (list, tuple)), type(input_spec)
batch_size, feat_dim, model_size, num_left_chunks = input_spec batch_size, feat_dim, model_size, num_left_chunks = input_spec
######################### infer_model.forward_encoder_chunk zero tensor online ############ ######################## infer_model.forward_encoder_chunk ############
# TODO: 80(feature dim) be configable input_spec = [
# (T,), int16
paddle.static.InputSpec(shape=[None], dtype='int16'),
]
infer_model.forward_feature = paddle.jit.to_static(infer_model.forward_feature, input_spec=input_spec)
######################### infer_model.forward_encoder_chunk ############
input_spec = [ input_spec = [
# xs, (B, T, D) # xs, (B, T, D)
paddle.static.InputSpec(shape=[batch_size, None, feat_dim], dtype='float32'), paddle.static.InputSpec(shape=[batch_size, None, feat_dim], dtype='float32'),
@ -499,8 +506,16 @@ class U2Tester(U2Trainer):
infer_model.forward_encoder_chunk = paddle.jit.to_static( infer_model.forward_encoder_chunk = paddle.jit.to_static(
infer_model.forward_encoder_chunk, input_spec=input_spec) infer_model.forward_encoder_chunk, input_spec=input_spec)
######################### infer_model.ctc_activation ########################
input_spec = [
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)
######################### infer_model.forward_attention_decoder ######################## ######################### infer_model.forward_attention_decoder ########################
# TODO: 512(encoder_output) be configable. 1 for BatchSize
input_spec = [ input_spec = [
# hyps, (B, U) # hyps, (B, U)
paddle.static.InputSpec(shape=[None, None], dtype='int64'), paddle.static.InputSpec(shape=[None, None], dtype='int64'),
@ -512,17 +527,11 @@ class U2Tester(U2Trainer):
infer_model.forward_attention_decoder = paddle.jit.to_static( infer_model.forward_attention_decoder = paddle.jit.to_static(
infer_model.forward_attention_decoder, input_spec=input_spec) infer_model.forward_attention_decoder, input_spec=input_spec)
######################### infer_model.ctc_activation ########################
input_spec = [
# encoder_out, (B,T,D)
paddle.static.InputSpec(shape=[batch_size, None, model_size], dtype='float32')
]
infer_model.ctc_activation = paddle.jit.to_static(
infer_model.ctc_activation, input_spec=input_spec)
# jit save # jit save
logger.info(f"export save: {self.args.export_path}")
paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True) paddle.jit.save(infer_model, self.args.export_path, combine_params=True, skip_forward=True)
# test dy2static # test dy2static
def flatten(out): def flatten(out):
if isinstance(out, paddle.Tensor): if isinstance(out, paddle.Tensor):
@ -536,26 +545,44 @@ class U2Tester(U2Trainer):
flatten_out.append(var) flatten_out.append(var)
return flatten_out return flatten_out
xs1 = paddle.rand(shape=[1, 67, 80], dtype='float32') # forward_encoder_chunk dygraph
xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
offset = paddle.to_tensor([0], dtype='int32') offset = paddle.to_tensor([0], dtype='int32')
required_cache_size = num_left_chunks required_cache_size = num_left_chunks
att_cache = paddle.zeros([0, 0, 0, 0]) att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache = paddle.zeros([0, 0, 0, 0]) cnn_cache = paddle.zeros([0, 0, 0, 0])
xs_d, att_cache_d, cnn_cache_d = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache)
xs, att_cache, cnn_cache = infer_model.forward_encoder_chunk(xs1, offset, required_cache_size, att_cache, cnn_cache) import soundfile
xs2 = paddle.rand(shape=[1, 67, 80], dtype='float32') audio, sample_rate = soundfile.read(
offset = paddle.to_tensor([16], dtype='int32') './zh.wav', dtype="int16", always_2d=True)
out1 = infer_model.forward_encoder_chunk(xs2, offset, required_cache_size, att_cache, cnn_cache) audio = audio[:, 0]
print('py encoder', out1) logger.info(f"audio shape: {audio.shape}")
audio = paddle.to_tensor(audio, paddle.int16)
feat_d = infer_model.forward_feature(audio)
logger.info(f"{feat_d}")
np.savetxt("feat.tostatic.txt", feat_d)
# load static model
from paddle.jit.layer import Layer from paddle.jit.layer import Layer
layer = Layer() layer = Layer()
layer.load(self.args.export_path, paddle.CPUPlace()) layer.load(self.args.export_path, paddle.CPUPlace())
xs1 = paddle.full([1, 7, 80], 0.1, dtype='float32') # forward_encoder_chunk static
xs1 = paddle.full([1, 67, 80], 0.1, dtype='float32')
offset = paddle.to_tensor([0], dtype='int32') offset = paddle.to_tensor([0], dtype='int32')
att_cache = paddle.zeros([0, 0, 0, 0]) att_cache = paddle.zeros([0, 0, 0, 0])
cnn_cache=paddle.zeros([0, 0, 0, 0]) cnn_cache = paddle.zeros([0, 0, 0, 0])
func = getattr(layer, 'forward_encoder_chunk') func = getattr(layer, 'forward_encoder_chunk')
xs, att_cache, cnn_cache = func(xs1, offset, att_cache, cnn_cache) xs_s, att_cache_s, cnn_cache_s = func(xs1, offset, att_cache, cnn_cache)
print('py static encoder', xs) np.testing.assert_allclose(xs_d, xs_s, atol=1e-5)
np.testing.assert_allclose(att_cache_d, att_cache_s, atol=1e-4)
np.testing.assert_allclose(cnn_cache_d, cnn_cache_s, atol=1e-4)
# logger.info(f"forward_encoder_chunk output: {xs_s}")
# forward_feature static
func = getattr(layer, 'forward_feature')
feat_s = func(audio)[0]
logger.info(f"{feat_s}")
np.testing.assert_allclose(feat_d, feat_s, atol=1e-5)

@ -916,6 +916,50 @@ class U2InferModel(U2Model):
def __init__(self, configs: dict): def __init__(self, configs: dict):
super().__init__(configs) super().__init__(configs)
from paddlespeech.s2t.modules.fbank import KaldiFbank
import yaml
import json
import numpy as np
input_dim = configs['input_dim']
process = configs['preprocess_config']
with open(process, encoding="utf-8") as f:
conf = yaml.safe_load(f)
assert isinstance(conf, dict), type(self.conf)
for idx, process in enumerate(conf['process']):
assert isinstance(process, dict), type(process)
opts = dict(process)
process_type = opts.pop("type")
if process_type == 'fbank_kaldi':
opts.update({'n_mels': input_dim})
opts['dither'] = 0.0
self.fbank = KaldiFbank(
**opts
)
logger.info(f"{self.__class__.__name__} export: {self.fbank}")
if process_type == 'cmvn_json':
# align with paddlespeech.audio.transform.cmvn:GlobalCMVN
std_floor = 1.0e-20
cmvn = opts['cmvn_path']
if isinstance(cmvn, dict):
cmvn_stats = cmvn
else:
with open(cmvn) as f:
cmvn_stats = json.load(f)
count = cmvn_stats['frame_num']
mean = np.array(cmvn_stats['mean_stat']) / count
square_sums = np.array(cmvn_stats['var_stat'])
var = square_sums / count - mean**2
std = np.maximum(np.sqrt(var), std_floor)
istd = 1.0 / std
self.global_cmvn = GlobalCMVN(
paddle.to_tensor(mean, dtype=paddle.float),
paddle.to_tensor(istd, dtype=paddle.float))
logger.info(f"{self.__class__.__name__} export: {self.global_cmvn}")
def forward(self, def forward(self,
feats, feats,
feats_lengths, feats_lengths,
@ -939,3 +983,17 @@ class U2InferModel(U2Model):
# num_decoding_left_chunks=num_decoding_left_chunks, # num_decoding_left_chunks=num_decoding_left_chunks,
# simulate_streaming=simulate_streaming) # simulate_streaming=simulate_streaming)
return feats, feats_lengths return feats, feats_lengths
def forward_feature(self, x):
"""feature pipeline.
Args:
x (paddle.Tensor): waveform (T,).
Return:
feat (paddle.Tensor): feature (T, D)
"""
x = paddle.cast(x, paddle.float32)
feat = self.fbank(x)
feat = self.global_cmvn(feat)
return feat

@ -40,6 +40,14 @@ class GlobalCMVN(nn.Layer):
self.register_buffer("mean", mean) self.register_buffer("mean", mean)
self.register_buffer("istd", istd) self.register_buffer("istd", istd)
def __repr__(self):
return (
"{name}(mean={mean}, istd={istd}, norm_var={norm_var})".format(
name=self.__class__.__name__,
mean=self.mean,
istd=self.istd,
norm_var=self.norm_var))
def forward(self, x: paddle.Tensor): def forward(self, x: paddle.Tensor):
""" """
Args: Args:

@ -0,0 +1,74 @@
import paddle
from paddle import nn
from paddlespeech.audio.compliance import kaldi
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ['KaldiFbank']
class KaldiFbank(nn.Layer):
def __init__(self,
fs=16000,
n_mels=80,
n_shift=160, # unit:sample, 10ms
win_length=400, # unit:sample, 25ms
energy_floor=0.0,
dither=0.0):
"""
Args:
fs (int): sample rate of the audio
n_mels (int): number of mel filter banks
n_shift (int): number of points in a frame shift
win_length (int): number of points in a frame windows
energy_floor (float): Floor on energy in Spectrogram computation (absolute)
dither (float): Dithering constant. Default 0.0
"""
super().__init__()
self.fs = fs
self.n_mels = n_mels
num_point_ms = fs / 1000
self.n_frame_length = win_length / num_point_ms
self.n_frame_shift = n_shift / num_point_ms
self.energy_floor = energy_floor
self.dither = dither
def __repr__(self):
return (
"{name}(fs={fs}, n_mels={n_mels}, "
"n_frame_shift={n_frame_shift}, n_frame_length={n_frame_length}, "
"dither={dither}))".format(
name=self.__class__.__name__,
fs=self.fs,
n_mels=self.n_mels,
n_frame_shift=self.n_frame_shift,
n_frame_length=self.n_frame_length,
dither=self.dither, ))
def forward(self, x: paddle.Tensor):
"""
Args:
x (paddle.Tensor): shape (Ti).
Not support: [Time, Channel] and Batch mode.
Returns:
paddle.Tensor: (T, D)
"""
assert x.ndim == 1
feat = kaldi.fbank(
x.unsqueeze(0), # append channel dim, (C, Ti)
n_mels=self.n_mels,
frame_length=self.n_frame_length,
frame_shift=self.n_frame_shift,
dither=self.dither,
energy_floor=self.energy_floor,
sr=self.fs)
assert feat.ndim == 2 # (T,D)
return feat
Loading…
Cancel
Save