add josn global cmvn

pull/1012/head
Hui Zhang 3 years ago
parent 9cdd2643b1
commit 6a7e0265cd

@ -1,22 +1,24 @@
process:
# extract kaldi fbank from PCM
- type: "fbank_kaldi"
- type: fbank_kaldi
fs: 16000
n_mels: 80
n_shift: 160
win_length: 400
dither: true
- type: cmvn_json
cmvn_path: data/mean_std.json
# these three processes are a.k.a. SpecAugument
- type: "time_warp"
- type: time_warp
max_time_warp: 5
inplace: true
mode: "PIL"
- type: "freq_mask"
mode: PIL
- type: freq_mask
F: 30
n_mask: 2
inplace: true
replace_with_zero: false
- type: "time_mask"
- type: time_mask
T: 40
n_mask: 2
inplace: true

@ -11,7 +11,7 @@ data:
max_output_input_ratio: 10.0
collator:
mean_std_filepath: ""
mean_std_filepath: data/mean_std.json
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_200'
@ -37,7 +37,7 @@ collator:
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: transformer

@ -13,12 +13,11 @@
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import io
import json
import h5py
import kaldiio
import numpy as np
class CMVN():
"Apply Global/Spk CMVN/iverserCMVN."
@ -157,3 +156,37 @@ class UtteranceCMVN():
x = np.divide(x, std)
return x
class GlobalCMVN():
"Apply Global CMVN"
def __init__(self, cmvn_path, norm_means=True, norm_vars=True, std_floor=1.0e-20):
self.cmvn_path = cmvn_path
self.norm_means = norm_means
self.norm_vars = norm_vars
self.std_floor = std_floor
with open(cmvn_path) as f:
cmvn_stats = json.load(f)
self.count = cmvn_stats['frame_num']
self.mean = np.array(cmvn_stats['mean_stat']) / self.count
self.square_sums = np.array(cmvn_stats['var_stat'])
self.var = self.square_sums / self.count - self.mean**2
self.std = np.maximum(np.sqrt(self.var), self.std_floor)
def __repr__(self):
return f"""{self.__class__.__name__}(
cmvn_path={self.cmvn_path},
norm_means={self.norm_means},
norm_vars={self.norm_vars},)"""
def __call__(self, x, uttid=None):
# x: [Time, Dim]
if self.norm_means:
x = np.subtract(x, self.mean)
if self.norm_vars:
x = np.divide(x, self.std)
return x

@ -46,6 +46,7 @@ import_alias = dict(
wpe="paddlespeech.s2t.transform.wpe:WPE",
channel_selector="paddlespeech.s2t.transform.channel_selector:ChannelSelector",
fbank_kaldi="paddlespeech.s2t.transform.spectrogram:LogMelSpectrogramKaldi",
cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN"
)

Loading…
Cancel
Save