diff --git a/examples/tiny/s1/conf/preprocess.yaml b/examples/tiny/s1/conf/preprocess.yaml index 9de0d8c7..dd4cfd27 100644 --- a/examples/tiny/s1/conf/preprocess.yaml +++ b/examples/tiny/s1/conf/preprocess.yaml @@ -1,22 +1,24 @@ process: # extract kaldi fbank from PCM - - type: "fbank_kaldi" + - type: fbank_kaldi fs: 16000 n_mels: 80 n_shift: 160 win_length: 400 dither: true + - type: cmvn_json + cmvn_path: data/mean_std.json # these three processes are a.k.a. SpecAugument - - type: "time_warp" + - type: time_warp max_time_warp: 5 inplace: true - mode: "PIL" - - type: "freq_mask" + mode: PIL + - type: freq_mask F: 30 n_mask: 2 inplace: true replace_with_zero: false - - type: "time_mask" + - type: time_mask T: 40 n_mask: 2 inplace: true diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 87f9c243..1378e848 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -11,7 +11,7 @@ data: max_output_input_ratio: 10.0 collator: - mean_std_filepath: "" + mean_std_filepath: data/mean_std.json vocab_filepath: data/vocab.txt unit_type: 'spm' spm_model_prefix: 'data/bpe_unigram_200' @@ -37,7 +37,7 @@ collator: # network architecture model: - cmvn_file: "data/mean_std.json" + cmvn_file: cmvn_file_type: "json" # encoder related encoder: transformer diff --git a/paddlespeech/s2t/transform/cmvn.py b/paddlespeech/s2t/transform/cmvn.py index 4d2d2324..dc9ea87e 100644 --- a/paddlespeech/s2t/transform/cmvn.py +++ b/paddlespeech/s2t/transform/cmvn.py @@ -13,12 +13,11 @@ # limitations under the License. # Modified from espnet(https://github.com/espnet/espnet) import io - +import json import h5py import kaldiio import numpy as np - class CMVN(): "Apply Global/Spk CMVN/iverserCMVN." 
class GlobalCMVN():
    """Apply global CMVN using accumulated statistics stored in a JSON file.

    The JSON file must provide three keys:

    * ``frame_num``: number of frames the statistics were accumulated over.
    * ``mean_stat``: per-dimension sum of the features.
    * ``var_stat``: per-dimension sum of the squared features.

    The class is registered in ``transformation.import_alias`` as
    ``cmvn_json="paddlespeech.s2t.transform.cmvn:GlobalCMVN"``, so a
    preprocess config can select it with ``type: cmvn_json`` and a
    ``cmvn_path`` entry.

    Args:
        cmvn_path: path of the JSON statistics file.
        norm_means: subtract the global mean when True.
        norm_vars: divide by the global standard deviation when True.
        std_floor: lower bound applied to the standard deviation so the
            division in ``__call__`` never divides by zero.

    Raises:
        ValueError: if ``frame_num`` is not strictly positive.
    """

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        self.cmvn_path = cmvn_path
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(cmvn_path, encoding="utf-8") as f:
            cmvn_stats = json.load(f)

        self.count = cmvn_stats['frame_num']
        if self.count <= 0:
            # Dividing by a zero frame count would silently yield inf/NaN
            # statistics and corrupt every feature passed through.
            raise ValueError(
                f"frame_num must be positive in {cmvn_path}, got {self.count}")

        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        # var = E[x^2] - E[x]^2
        self.var = self.square_sums / self.count - self.mean**2
        # Floating-point cancellation can leave the variance slightly
        # negative. np.sqrt would then return NaN, and np.maximum propagates
        # NaN, so std_floor would never take effect. Clamp at zero first.
        non_negative_var = np.maximum(self.var, 0.0)
        self.std = np.maximum(np.sqrt(non_negative_var), self.std_floor)

    def __repr__(self):
        return f"""{self.__class__.__name__}(
        cmvn_path={self.cmvn_path},
        norm_means={self.norm_means},
        norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        """Normalize ``x`` with the global statistics and return the result.

        Args:
            x: feature array, laid out as [Time, Dim].
            uttid: unused here; presumably kept so the call signature
                matches the other transforms -- confirm against callers.

        Returns:
            A new array; ``x`` itself is not modified in place.
        """
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x