|
|
|
@ -13,12 +13,11 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# Modified from espnet(https://github.com/espnet/espnet)
|
|
|
|
|
import io
|
|
|
|
|
|
|
|
|
|
import json
|
|
|
|
|
import h5py
|
|
|
|
|
import kaldiio
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CMVN():
|
|
|
|
|
"Apply Global/Spk CMVN/inverse CMVN."
|
|
|
|
|
|
|
|
|
@ -157,3 +156,37 @@ class UtteranceCMVN():
|
|
|
|
|
x = np.divide(x, std)
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GlobalCMVN():
    """Apply global cepstral mean and variance normalization (CMVN).

    The statistics are read once, at construction time, from a JSON file
    that must contain the keys ``frame_num``, ``mean_stat`` and
    ``var_stat``. ``mean_stat`` is divided by ``frame_num`` to obtain the
    mean, and ``var_stat`` is treated as the per-dimension sum of squares.
    """

    def __init__(self,
                 cmvn_path,
                 norm_means=True,
                 norm_vars=True,
                 std_floor=1.0e-20):
        """Load the accumulated statistics and derive mean and std.

        Args:
            cmvn_path: Path to the JSON statistics file.
            norm_means: If True, ``__call__`` subtracts the global mean.
            norm_vars: If True, ``__call__`` divides by the global std.
            std_floor: Lower bound applied to the standard deviation so the
                division in ``__call__`` never uses a zero divisor.

        Raises:
            ValueError: If ``frame_num`` in the statistics is not positive.
        """
        self.cmvn_path = cmvn_path
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.std_floor = std_floor

        with open(cmvn_path, encoding="utf-8") as f:
            cmvn_stats = json.load(f)

        self.count = cmvn_stats['frame_num']
        if self.count <= 0:
            # Dividing by a non-positive count would silently yield inf/NaN
            # statistics and corrupt every feature normalized afterwards.
            raise ValueError(
                f"frame_num must be positive, got {self.count} "
                f"in {cmvn_path}")

        self.mean = np.array(cmvn_stats['mean_stat']) / self.count
        self.square_sums = np.array(cmvn_stats['var_stat'])
        self.var = self.square_sums / self.count - self.mean**2
        # E[x^2] - E[x]^2 can come out slightly negative because of floating
        # point rounding. sqrt() of a negative value is NaN, and np.maximum
        # propagates NaN, so the floor alone would not protect the result.
        # Clamp at zero before taking the square root.
        non_negative_var = np.maximum(self.var, 0.0)
        self.std = np.maximum(np.sqrt(non_negative_var), self.std_floor)

    def __repr__(self):
        """Return a short description of the configuration."""
        return f"""{self.__class__.__name__}(
            cmvn_path={self.cmvn_path},
            norm_means={self.norm_means},
            norm_vars={self.norm_vars},)"""

    def __call__(self, x, uttid=None):
        """Normalize ``x`` with the global statistics.

        Args:
            x: Feature array, [Time, Dim].
            uttid: Accepted but not used by this transform.

        Returns:
            The normalized array. The input is not modified in place.
        """
        # x: [Time, Dim]
        if self.norm_means:
            x = np.subtract(x, self.mean)

        if self.norm_vars:
            x = np.divide(x, self.std)
        return x
|