# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import paddle


class InputNormalization:
    """Mean / std normalization of features or embeddings.

    Only ``norm_type == "global"`` changes the input: a running average of
    the per-batch mean and std is kept while ``self.training`` is true and
    is used to normalize every input. For any other ``norm_type`` the input
    is returned unchanged by ``__call__``.
    """

    # Per-speaker statistics. They are stored, saved and loaded by this
    # class but never filled in by ``__call__``.
    spk_dict_mean: Dict[int, paddle.Tensor]
    spk_dict_std: Dict[int, paddle.Tensor]
    spk_dict_count: Dict[int, int]

    def __init__(
            self,
            mean_norm=True,
            std_norm=True,
            norm_type="global", ):
        """Do feature or embedding mean and std norm

        Args:
            mean_norm (bool, optional): mean norm flag. Defaults to True.
            std_norm (bool, optional): std norm flag. Defaults to True.
            norm_type (str, optional): norm type. Defaults to "global".
        """
        super().__init__()
        # Running statistics are only updated while this flag is true.
        self.training = True
        self.mean_norm = mean_norm
        self.std_norm = std_norm
        self.norm_type = norm_type
        self.glob_mean = paddle.to_tensor([0], dtype="float32")
        # NOTE: the std starts at 0, so normalizing with ``training`` set to
        # False before any statistics were accumulated or loaded divides by
        # zero.
        self.glob_std = paddle.to_tensor([0], dtype="float32")
        self.spk_dict_mean = {}
        self.spk_dict_std = {}
        self.spk_dict_count = {}
        self.weight = 1.0
        self.count = 0
        self.eps = 1e-10

    def __call__(self,
                 x,
                 lengths,
                 spk_ids=paddle.to_tensor([], dtype="float32")):
        """Normalize a batch with the (running) global statistics.

        Args:
            x (paddle.Tensor): A batch of tensors.
            lengths (paddle.Tensor): A batch of tensors containing the
                relative length of each sentence (e.g, [0.7, 0.9, 1.0]).
                It is used to avoid computing stats on zero-padded steps.
            spk_ids (paddle.Tensor, optional): tensor containing the ids of
                each speaker (e.g, [0 10 6]). Currently unused: per-speaker
                normalization is not implemented here.
                Defaults to paddle.to_tensor([], dtype="float32").

        Returns:
            paddle.Tensor: The normalized feature or embedding when
                ``norm_type == "global"``, otherwise ``x`` unchanged.
        """
        n_batches = x.shape[0]

        current_means = []
        current_stds = []
        for snt_id in range(n_batches):
            # Avoid padded steps: keep only the first `actual_size` entries
            # along axis 1 of this item.
            actual_size = paddle.round(lengths[snt_id] *
                                       x.shape[1]).astype("int32")

            # Statistics of the un-padded part of this item.
            # NOTE(review): the slice gets a leading axis of size 1 and the
            # statistics are reduced over axis 0 -- confirm this is the
            # intended reduction axis for inputs with more than 2 dims.
            current_mean, current_std = self._compute_current_stats(
                x[snt_id, 0:actual_size, ...].unsqueeze(0))
            current_means.append(current_mean)
            current_stds.append(current_std)

        if self.norm_type != "global":
            return x

        # Batch statistics: average of the per-item statistics.
        current_mean = paddle.mean(paddle.stack(current_means), axis=0)
        current_std = paddle.mean(paddle.stack(current_stds), axis=0)

        if self.training:
            if self.count == 0:
                self.glob_mean = current_mean
                self.glob_std = current_std
            else:
                # Incremental average over all batches seen so far.
                self.weight = 1 / (self.count + 1)
                self.glob_mean = (
                    1 - self.weight
                ) * self.glob_mean + self.weight * current_mean
                self.glob_std = (
                    1 - self.weight
                ) * self.glob_std + self.weight * current_std

            # detach() returns a new tensor, it does not work in place, so
            # the result has to be stored back.
            self.glob_mean = self.glob_mean.detach()
            self.glob_std = self.glob_std.detach()
            self.count = self.count + 1

        return (x - self.glob_mean) / self.glob_std

    def _compute_current_stats(self, x):
        """Compute the mean and std of ``x`` along axis 0.

        Args:
            x (paddle.Tensor): A batch of tensors.

        Returns:
            tuple: ``(current_mean, current_std)``. The mean is a constant 0
                when ``mean_norm`` is off and the std is a constant 1 when
                ``std_norm`` is off. The std is never smaller than
                ``self.eps``.
        """
        # Compute current mean
        if self.mean_norm:
            current_mean = paddle.mean(x, axis=0).detach()
        else:
            current_mean = paddle.to_tensor([0.0], dtype="float32")

        # Compute current std
        if self.std_norm:
            current_std = paddle.std(x, axis=0).detach()
        else:
            current_std = paddle.to_tensor([1.0], dtype="float32")

        # Improving numerical stability of std
        current_std = paddle.maximum(current_std,
                                     self.eps * paddle.ones_like(current_std))

        return current_mean, current_std

    def _statistics_dict(self):
        """Build the dictionary containing the normalization statistics.

        Returns:
            dict: the counter, the global mean/std and the per-speaker
                dictionaries, keyed by attribute name.
        """
        return {
            "count": self.count,
            "glob_mean": self.glob_mean,
            "glob_std": self.glob_std,
            "spk_dict_mean": self.spk_dict_mean,
            "spk_dict_std": self.spk_dict_std,
            "spk_dict_count": self.spk_dict_count,
        }

    def _load_statistics_dict(self, state):
        """Load the statistics from a dictionary.

        Args:
            state (dict): A dictionary containing the normalization
                statistics, with the keys produced by ``_statistics_dict``.

        Returns:
            dict: the ``state`` argument, unchanged.
        """
        self.count = state["count"]
        self.glob_mean = state["glob_mean"]
        self.glob_std = state["glob_std"]

        # Copy the per-speaker entries into fresh dictionaries.
        self.spk_dict_mean = {
            spk: state["spk_dict_mean"][spk]
            for spk in state["spk_dict_mean"]
        }
        self.spk_dict_std = {
            spk: state["spk_dict_std"][spk]
            for spk in state["spk_dict_std"]
        }
        self.spk_dict_count = state["spk_dict_count"]

        return state

    def to(self, device):
        """Put the needed tensors in the right device.

        Args:
            device: forwarded to the ``to`` method of every stored tensor.

        Returns:
            InputNormalization: ``self``, to allow chaining.
        """
        # This class has no base class providing ``to``, so only the stored
        # tensors are moved.
        self.glob_mean = self.glob_mean.to(device)
        self.glob_std = self.glob_std.to(device)
        for spk in self.spk_dict_mean:
            self.spk_dict_mean[spk] = self.spk_dict_mean[spk].to(device)
            self.spk_dict_std[spk] = self.spk_dict_std[spk].to(device)
        return self

    def save(self, path):
        """Save statistic dictionary.

        Args:
            path (str): A path where to save the dictionary.
        """
        stats = self._statistics_dict()
        paddle.save(stats, path)

    def _load(self, path, end_of_epoch=False, device=None):
        """Load statistic dictionary.

        Args:
            path (str): The path of the statistic dictionary.
            end_of_epoch (bool, optional): Unused here. Defaults to False.
            device (str, optional): Kept for interface compatibility and
                currently ignored; call ``to(device)`` afterwards to move
                the loaded tensors. Defaults to None.
        """
        del end_of_epoch  # Unused here.
        # ``map_location`` is a torch.load option; paddle.load has no such
        # config, so the device is not forwarded.
        del device
        stats = paddle.load(path)
        self._load_statistics_dict(stats)