diff --git a/.flake8 b/.flake8
index 6b50de7e..ae15ad2b 100644
--- a/.flake8
+++ b/.flake8
@@ -33,7 +33,7 @@ filename =
 # Specify a list of codes to ignore.
 ignore =
     W503
-    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
+    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125,E129
     W291,W293,W605
     E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
diff --git a/examples/librispeech/README.md b/examples/librispeech/README.md
index 74441fd0..9fcbde97 100644
--- a/examples/librispeech/README.md
+++ b/examples/librispeech/README.md
@@ -3,7 +3,7 @@
 * asr0 - deepspeech2 Streaming/Non-Streaming
 * asr1 - transformer/conformer Streaming/Non-Streaming
 * asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
-
+* asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC
 ## Data
 | Data Subset | Duration in Seconds |
diff --git a/paddlespeech/audio/transform/spectrogram.py b/paddlespeech/audio/transform/spectrogram.py
index 864f3f99..2e519939 100644
--- a/paddlespeech/audio/transform/spectrogram.py
+++ b/paddlespeech/audio/transform/spectrogram.py
@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
         return mat
 
 
+class WavProcess():
+    def __init__(self, dither=0.1):
+        """
+        Args:
+            dither (float): Dithering constant
+
+        Returns:
+        """
+
+        self.dither = dither
+
+    def __call__(self, x, train):
+        """
+        Args:
+            x (np.ndarray): shape (Ti,)
+            train (bool): True, train mode.
+
+        Raises:
+            ValueError: not support (Ti, C)
+
+        Returns:
+            np.ndarray: (T, D)
+        """
+        dither = self.dither if train else 0.0
+        if x.ndim != 1:
+            raise ValueError("Not support x: [Time, Channel]")
+        waveform = np.expand_dims(x, -1)
+        return waveform
+
+
 class LogMelSpectrogramKaldi_decay():
     def __init__(
             self,
diff --git a/paddlespeech/audio/transform/transformation.py b/paddlespeech/audio/transform/transformation.py
index d24d6437..e2f66dbf 100644
--- a/paddlespeech/audio/transform/transformation.py
+++ b/paddlespeech/audio/transform/transformation.py
@@ -41,6 +41,7 @@ import_alias = dict(
     utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
     fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
     spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
+    wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
     stft="paddlespeech.audio.transform.spectrogram:Stft",
     istft="paddlespeech.audio.transform.spectrogram:IStft",
     stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",
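[Editor's note] The two hunks above add WavProcess, a raw-waveform pass-through transform, and register it under the wav_process alias in import_alias. A minimal usage sketch; the {"process": [...]} config layout and the forwarding of train=True into the transform are assumptions based on this module's espnet-style Transformation API, not shown in this patch:

import numpy as np

from paddlespeech.audio.transform.transformation import Transformation

# "type" selects the class through import_alias; the remaining keys are
# forwarded to WavProcess.__init__ (here, the dithering constant).
preprocess_conf = {"process": [{"type": "wav_process", "dither": 0.1}]}

preprocessing = Transformation(preprocess_conf)
audio = np.random.randn(16000).astype(np.float32)  # (Ti,) mono waveform
feat = preprocessing(audio, train=True)  # (Ti, 1) per WavProcess.__call__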
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py b/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
index 5306d7f8..3a537bce 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/test_wav.py
@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
 from paddlespeech.s2t.utils.utility import UpdateConfig
 
 logger = Log(__name__).getlog()
+
 class Wav2vec2Infer():
     def __init__(self, config, args):
         self.args = args
@@ -34,8 +35,7 @@ class Wav2vec2Infer():
         self.audio_file = args.audio_file
 
         self.text_feature = TextFeaturizer(
-            unit_type=config.unit_type,
-            vocab=config.vocab_filepath)
+            unit_type=config.unit_type, vocab=config.vocab_filepath)
         paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
 
         # model
@@ -63,10 +63,10 @@ class Wav2vec2Infer():
             xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
             decode_config = self.config.decode
             result_transcripts, result_tokenids = self.model.decode(
-                xs,
-                text_feature=self.text_feature,
-                decoding_method=decode_config.decoding_method,
-                beam_size=decode_config.beam_size)
+                xs,
+                text_feature=self.text_feature,
+                decoding_method=decode_config.decoding_method,
+                beam_size=decode_config.beam_size)
             rsl = result_transcripts[0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {rsl}")
diff --git a/paddlespeech/s2t/exps/wav2vec2/model.py b/paddlespeech/s2t/exps/wav2vec2/model.py
index 3d9c266e..32cf0b47 100644
--- a/paddlespeech/s2t/exps/wav2vec2/model.py
+++ b/paddlespeech/s2t/exps/wav2vec2/model.py
@@ -18,53 +18,53 @@ import time
 from collections import defaultdict
 from collections import OrderedDict
 from contextlib import nullcontext
-from paddlespeech.s2t.utils import mp_tools
 
 import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
+
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
-from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
-from paddlespeech.s2t.utils import error_rate
-
+from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
 from paddlespeech.s2t.training.reporter import report
 from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
 from paddlespeech.s2t.training.timer import Timer
 from paddlespeech.s2t.training.trainer import Trainer
-from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.s2t.utils import error_rate
 from paddlespeech.s2t.utils import layer_tools
+from paddlespeech.s2t.utils import mp_tools
 from paddlespeech.s2t.utils.log import Log
-
-
+from paddlespeech.s2t.utils.utility import UpdateConfig
 logger = Log(__name__).getlog()
 
+
 class Wav2Vec2ASRTrainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
         self.avg_train_loss = 0
+
     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
         start = time.time()
 
         # forward
         utt, wav, wavs_lens, target, target_lens = batch
-        wavs_lens_rate = wavs_lens / wav.shape[1]
+        wavs_lens_rate = wavs_lens / wav.shape[1]
         target_lens_rate = target_lens / target.shape[1]
-        wav = wav[:,:,0]
+        wav = wav[:, :, 0]
         wav = self.speech_augmentation(wav, wavs_lens_rate)
         loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
         # pring(wav, wavs_lens_rate, target, target_lens_rate)
         # loss div by `batch_size * accum_grad`
         loss /= train_conf.accum_grad
-
+
         losses_np = {'loss': float(loss) * train_conf.accum_grad}
 
         # loss backward
@@ -108,15 +108,16 @@ class Wav2Vec2ASRTrainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
-            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+            logger.info(
+                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
         for i, batch in enumerate(self.valid_loader):
             utt, wav, wavs_lens, target, target_lens = batch
-            wavs_lens_rate = wavs_lens / wav.shape[1]
+            wavs_lens_rate = wavs_lens / wav.shape[1]
             target_lens_rate = target_lens / target.shape[1]
-            wav = wav[:,:,0]
+            wav = wav[:, :, 0]
             loss = self.model(wav, wavs_lens_rate, target,
                               target_lens_rate)
 
             if paddle.isfinite(loss):
@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                    msg += "batch: {}/{}, ".format(i + 1,
+                                                   len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
         self.before_train()
 
         if not self.use_streamdata:
-            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+            logger.info(
+                f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
         config = self.config.clone()
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
                 'decode_batch_size', 1)
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
-            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
+                                                                self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader(
+                'align', config, self.args)
             logger.info("Setup test/align Dataloader!")
 
     def setup_model(self):
@@ -248,7 +255,7 @@ class Wav2Vec2ASRTrainer(Trainer):
         model = Wav2vec2ASR.from_config(model_conf)
 
         if self.parallel:
-            model = paddle.DataParallel(model, find_unused_parameters=True)
+            model = paddle.DataParallel(model, find_unused_parameters=True)
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
         self.text_featurizer = TextFeaturizer(
             unit_type=config.unit_type, vocab=config.vocab_filepath)
         self.vocab_list = self.text_featurizer.vocab_list
+
     def id2token(self, texts, texts_len):
         """ ord() id to chr() chr """
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(
-                self.text_featurizer.defeaturize(ids.numpy().tolist()))
+            trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
         return trans
 
     def compute_metrics(self,
@@ -337,10 +344,10 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len)
         result_transcripts, result_tokenids = self.model.decode(
-            audio,
-            text_feature=self.text_featurizer,
-            decoding_method=decode_cfg.decoding_method,
-            beam_size=decode_cfg.beam_size)
+            audio,
+            text_feature=self.text_featurizer,
+            decoding_method=decode_cfg.decoding_method,
+            beam_size=decode_cfg.beam_size)
         decode_time = time.time() - start_time
 
         for utt, target, result, rec_tids in zip(
@@ -432,4 +439,4 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
                 "decode_method": self.config.decode.decoding_method,
             })
-            f.write(data + '\n')
\ No newline at end of file
+            f.write(data + '\n')
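[Editor's note] In train_batch above, the loss is divided by accum_grad before backward and only multiplied back for logging. A self-contained sketch of why that scaling makes gradient accumulation equivalent to one larger batch (generic PaddlePaddle, not this trainer's code):

import paddle

model = paddle.nn.Linear(4, 1)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
accum_grad = 4

for _ in range(accum_grad):
    x = paddle.randn([8, 4])
    y = paddle.randn([8, 1])
    loss = paddle.nn.functional.mse_loss(model(x), y)
    # Dividing keeps the accumulated (summed) gradient equal to the
    # gradient of the mean loss over all accum_grad micro-batches.
    (loss / accum_grad).backward()  # grads accumulate until clear_grad

opt.step()  # one optimizer update per accum_grad micro-batches
opt.clear_grad()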
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
index a8f5f5cb..ae141d1b 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/VanillaNN.py
@@ -3,6 +3,7 @@ Authors
 * Elena Rastorgueva 2020
 """
 import paddle
+
 from paddlespeech.s2t.models.wav2vec2.modules import containers
 from paddlespeech.s2t.models.wav2vec2.modules import linear
 
@@ -27,12 +28,11 @@ class VanillaNN(containers.Sequential):
     """
 
     def __init__(
-        self,
-        input_shape,
-        activation=paddle.nn.LeakyReLU,
-        dnn_blocks=2,
-        dnn_neurons=512,
-    ):
+            self,
+            input_shape,
+            activation=paddle.nn.LeakyReLU,
+            dnn_blocks=2,
+            dnn_neurons=512, ):
         super().__init__(input_shape=input_shape)
 
         for block_index in range(dnn_blocks):
@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
             linear.Linear,
             n_neurons=dnn_neurons,
             bias=True,
-            layer_name="linear",
-        )
+            layer_name="linear", )
         self.append(activation(), layer_name="act")
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/activations.py b/paddlespeech/s2t/models/wav2vec2/modules/activations.py
index 9df652c2..722d8a0d 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/activations.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/activations.py
@@ -11,12 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import math
-from packaging import version
-from paddle import Tensor, nn
-
+from paddle import nn
+from paddle import Tensor
 from paddlespeech.s2t.utils.log import Log
 
 logger = Log(__name__).getlog()
@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
     """
 
     def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
+        return 0.5 * input * (1.0 + paddle.tanh(
+            math.sqrt(2.0 / math.pi) *
+            (input + 0.044715 * paddle.pow(input, 3.0))))
 
 
 class GELUActivation(nn.Layer):
@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
     Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """
 
-    def __init__(self, use_gelu_python: bool = False):
+    def __init__(self, use_gelu_python: bool=False):
         super().__init__()
         self.act = nn.functional.gelu
 
@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
     """
 
     def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+        return 0.5 * input * (
+            1.0 + paddle.tanh(input * 0.7978845608 *
+                              (1.0 + 0.044715 * input * input)))
 
 
 class QuickGELUActivation(nn.Layer):
@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
 
     def __init__(self, min: float, max: float):
         if min > max:
-            raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+            raise ValueError(
+                f"min should be < max (got min: {min}, max: {max})")
 
         super().__init__()
         self.min = min
@@ -161,7 +164,9 @@ def get_activation(activation_string):
     if activation_string in ACT2FN:
         return ACT2FN[activation_string]
     else:
-        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+        raise KeyError(
+            f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
+        )
 
 
 # For backwards compatibility with: from activations import gelu_python
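[Editor's note] NewGELUActivation and FastGELUActivation above are tanh-based approximations of the exact, erf-based GELU; the constant 0.7978845608 is sqrt(2/pi). A small numpy check (illustrative only) that the reformatted expressions still track the exact form:

import math

import numpy as np

erf = np.vectorize(math.erf)
x = np.linspace(-5.0, 5.0, 101)

gelu_exact = 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))
gelu_new = 0.5 * x * (1.0 + np.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
gelu_fast = 0.5 * x * (1.0 + np.tanh(x * 0.7978845608 *
                                     (1.0 + 0.044715 * x * x)))

# Both approximations stay within ~1e-3 of the exact GELU on this grid.
print(np.abs(gelu_new - gelu_exact).max())
print(np.abs(gelu_fast - gelu_exact).max())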
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/containers.py b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
index 2b961a59..b3973357 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/containers.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/containers.py
@@ -1,8 +1,7 @@
-import paddle
 import inspect
-import logging
-import operator
-import functools
+
+import paddle
+
 
 class Sequential(paddle.nn.LayerDict):
     """A sequence of modules with potentially inferring shape on construction.
@@ -98,13 +97,12 @@ class Sequential(paddle.nn.LayerDict):
         # Finally, append the layer.
         try:
             self[layer_name] = layer
-            # self.add_module(layer_name, layer)
+            # self.add_module(layer_name, layer)
         except TypeError:
             raise ValueError(
                 "Must pass `input_shape` at initialization and use "
                 "modules that take `input_shape` to infer shape when "
-                "using `append()`."
-            )
+                "using `append()`.")
 
     def get_output_shape(self):
         """Returns expected shape of the output.
diff --git a/paddlespeech/s2t/models/wav2vec2/modules/linear.py b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
index 26389d90..488949d1 100644
--- a/paddlespeech/s2t/models/wav2vec2/modules/linear.py
+++ b/paddlespeech/s2t/models/wav2vec2/modules/linear.py
@@ -3,10 +3,10 @@ Authors
 * Mirco Ravanelli 2020
 * Davide Borra 2021
 """
-
 import logging
+
 import paddle
-import paddle.nn as nn
+
 from paddlespeech.s2t.modules import align
 
 logger = logging.getLogger(__name__)
@@ -37,13 +37,12 @@ class Linear(paddle.nn.Layer):
     """
 
     def __init__(
-        self,
-        n_neurons,
-        input_shape=None,
-        input_size=None,
-        bias=True,
-        combine_dims=False,
-    ):
+            self,
+            n_neurons,
+            input_shape=None,
+            input_size=None,
+            bias=True,
+            combine_dims=False, ):
         super().__init__()
         self.combine_dims = combine_dims
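[Editor's note] containers.Sequential infers each appended layer's input size from input_shape at construction, which is why VanillaNN and Linear above can be built without explicit input sizes. A hedged usage sketch (shapes follow the VanillaNN defaults; the forward call assumes this shape-inferring Sequential behaves like its SpeechBrain counterpart):

import paddle

from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN

inputs = paddle.rand([10, 120, 60])  # (batch, time, feat)
model = VanillaNN(input_shape=inputs.shape)  # 2 blocks x 512 units by default
outputs = model(inputs)
print(outputs.shape)  # expected [10, 120, 512]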
+ ) first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + other_fields_are_none = all( + getattr(self, field.name) is None for field in class_fields[1:]) if other_fields_are_none and not paddle.is_tensor(first_field): if isinstance(first_field, dict): @@ -61,11 +64,9 @@ class ModelOutput(OrderedDict): # set the associated fields if first_field_iterator: for element in iterator: - if ( - not isinstance(element, (list, tuple)) - or not len(element) == 2 - or not isinstance(element[0], str) - ): + if (not isinstance(element, (list, tuple)) or + not len(element) == 2 or + not isinstance(element[0], str)): break setattr(self, element[0], element[1]) if element[1] is not None: @@ -79,16 +80,23 @@ class ModelOutput(OrderedDict): self[field.name] = v def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + raise Exception( + f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance." + ) def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + raise Exception( + f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance." + ) def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + raise Exception( + f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + raise Exception( + f"You cannot use ``update`` on a {self.__class__.__name__} instance." + ) def __getitem__(self, k): if isinstance(k, str): diff --git a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py index 6988aa6a..3d5e5fa6 100644 --- a/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py +++ b/paddlespeech/s2t/models/wav2vec2/modules/modeling_wav2vec2.py @@ -13,24 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Paddle Wav2Vec2 model.""" - -import math -import warnings -import paddle from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Optional +from typing import Tuple +from typing import Union import numpy as np +import paddle from paddle import nn from paddlespeech.s2t.models.wav2vec2.modules.activations import ACT2FN -from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ( - BaseModelOutput, - Wav2Vec2BaseModelOutput, - ModelOutput -) -import pdb - +from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import BaseModelOutput +from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ModelOutput +from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import Wav2Vec2BaseModelOutput from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() @@ -78,12 +73,11 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput): def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[paddle.Tensor] = None, - min_masks: int = 0, -) -> np.ndarray: + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[paddle.Tensor]=None, + min_masks: int=0, ) -> np.ndarray: """ Computes random mask spans for a given shape. 
Used to implement [SpecAugment: A Simple Data Augmentation Method for ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on @@ -109,8 +103,7 @@ def _compute_mask_indices( if mask_length > sequence_length: raise ValueError( f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) + f" and `sequence_length`: {sequence_length}`") # epsilon is used for probabilistic rounding epsilon = np.random.rand(1).item() @@ -131,11 +124,9 @@ def _compute_mask_indices( return num_masked_span # compute number of masked spans in batch - input_lengths = ( - attention_mask.sum(-1).detach().tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) + input_lengths = (attention_mask.sum(-1).detach().tolist() + if attention_mask is not None else + [sequence_length for _ in range(batch_size)]) # SpecAugment mask to fill spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool) @@ -152,8 +143,9 @@ def _compute_mask_indices( # get random indices to mask spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) + np.arange(input_length - (mask_length - 1)), + num_masked_span, + replace=False) # pick first sampled index that will serve as a dummy index to pad vector # to ensure same dimension for all batches due to probabilistic rounding @@ -166,29 +158,33 @@ def _compute_mask_indices( else: dummy_mask_idx = spec_aug_mask_idx[0] - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) + spec_aug_mask_idx = np.concatenate([ + spec_aug_mask_idx, + np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * + dummy_mask_idx + ]) spec_aug_mask_idxs.append(spec_aug_mask_idx) spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) # expand masked indices to masked spans spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape((batch_size, max_num_masked_span * mask_length)) + spec_aug_mask_idxs[:, :, None], + (batch_size, max_num_masked_span, mask_length)) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape( + (batch_size, max_num_masked_span * mask_length)) # add offset to the starting indexes so that indexes now create a span offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - (batch_size, max_num_masked_span * mask_length) - ) + offsets = np.broadcast_to(offsets, ( + batch_size, max_num_masked_span, mask_length)).reshape( + (batch_size, max_num_masked_span * mask_length)) spec_aug_mask_idxs = spec_aug_mask_idxs + offsets # ensure that we cannot have indices larger than sequence_length if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + spec_aug_mask_idxs[spec_aug_mask_idxs > + sequence_length - 1] = sequence_length - 1 # scatter indices to mask np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) @@ -196,9 +192,9 @@ def _compute_mask_indices( return spec_aug_mask -def _sample_negative_indices( - features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None -): +def _sample_negative_indices(features_shape: Tuple, + num_negatives: int, + mask_time_indices: 
Optional[np.ndarray]=None): """ Sample `num_negatives` vectors from feature vectors. """ @@ -208,23 +204,28 @@ def _sample_negative_indices( sequence_length_range = np.arange(sequence_length) # get `num_negatives` random vector indices from the same utterance - sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) + sampled_negative_indices = np.zeros( + shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) - mask_time_indices = ( - mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool) - ) + mask_time_indices = (mask_time_indices.astype(np.bool) + if mask_time_indices is not None else + np.ones(features_shape, dtype=np.bool)) for batch_idx in range(batch_size): high = mask_time_indices[batch_idx].sum() - 1 - mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] + mapped_masked_indices = sequence_length_range[mask_time_indices[ + batch_idx]] - feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) - sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) + feature_indices = np.broadcast_to( + np.arange(high + 1)[:, None], (high + 1, num_negatives)) + sampled_indices = np.random.randint( + 0, high, size=(high + 1, num_negatives)) # avoid sampling the same positive vector, but keep the distribution uniform sampled_indices[sampled_indices >= feature_indices] += 1 # remap to actual indices - sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] + sampled_negative_indices[batch_idx][mask_time_indices[ + batch_idx]] = mapped_masked_indices[sampled_indices] # correct for batch size sampled_negative_indices[batch_idx] += batch_idx * sequence_length @@ -243,8 +244,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Layer): self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], - bias_attr=config.conv_bias, - ) + bias_attr=config.conv_bias, ) self.activation = ACT2FN[config.feat_extract_activation] def forward(self, hidden_states): @@ -264,8 +264,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Layer): self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], - bias_attr=config.conv_bias, - ) + bias_attr=config.conv_bias, ) self.layer_norm = nn.LayerNorm(self.out_conv_dim) self.activation = ACT2FN[config.feat_extract_activation] @@ -290,11 +289,11 @@ class Wav2Vec2GroupNormConvLayer(nn.Layer): self.out_conv_dim, kernel_size=config.conv_kernel[layer_id], stride=config.conv_stride[layer_id], - bias_attr=config.conv_bias, - ) + bias_attr=config.conv_bias, ) self.activation = ACT2FN[config.feat_extract_activation] - self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim) + self.layer_norm = nn.GroupNorm( + num_groups=self.out_conv_dim, num_channels=self.out_conv_dim) def forward(self, hidden_states): hidden_states = self.conv(hidden_states) @@ -311,8 +310,7 @@ class Wav2Vec2PositionalConvEmbedding(nn.Layer): config.hidden_size, kernel_size=config.num_conv_pos_embeddings, padding=config.num_conv_pos_embeddings // 2, - groups=config.num_conv_pos_embedding_groups, - ) + groups=config.num_conv_pos_embedding_groups, ) self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) @@ -337,7 +335,7 @@ class Wav2Vec2SamePadLayer(nn.Layer): def forward(self, hidden_states): if self.num_pad_remove > 0: - hidden_states = hidden_states[:, :, : 
-self.num_pad_remove] + hidden_states = hidden_states[:, :, :-self.num_pad_remove] return hidden_states @@ -349,11 +347,13 @@ class Wav2Vec2FeatureEncoder(nn.Layer): if config.feat_extract_norm == "group": conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [ - Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) ] elif config.feat_extract_norm == "layer": conv_layers = [ - Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + Wav2Vec2LayerNormConvLayer(config, layer_id=i) + for i in range(config.num_feat_extract_layers) ] else: raise ValueError( @@ -373,10 +373,12 @@ class Wav2Vec2FeatureEncoder(nn.Layer): return hidden_states + class Wav2Vec2FeatureProjection(nn.Layer): def __init__(self, config): super().__init__() - self.layer_norm = nn.LayerNorm(config.conv_dim[-1], epsilon=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm( + config.conv_dim[-1], epsilon=config.layer_norm_eps) self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) self.dropout = nn.Dropout(config.feat_proj_dropout) @@ -393,13 +395,12 @@ class Wav2Vec2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - is_decoder: bool = False, - bias: bool = True, - ): + self, + embed_dim: int, + num_heads: int, + dropout: float=0.0, + is_decoder: bool=False, + bias: bool=True, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads @@ -409,8 +410,7 @@ class Wav2Vec2Attention(nn.Layer): if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." 
- ) + f" and `num_heads`: {num_heads}).") self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder @@ -420,17 +420,18 @@ class Wav2Vec2Attention(nn.Layer): self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): - return paddle.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)).transpose([0, 2, 1, 3]) + return paddle.reshape(tensor, (bsz, seq_len, self.num_heads, + self.head_dim)).transpose([0, 2, 1, 3]) def forward( - self, - hidden_states: paddle.Tensor, - key_value_states: Optional[paddle.Tensor] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, - attention_mask: Optional[paddle.Tensor] = None, - layer_head_mask: Optional[paddle.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + self, + hidden_states: paddle.Tensor, + key_value_states: Optional[paddle.Tensor]=None, + past_key_value: Optional[Tuple[paddle.Tensor]]=None, + attention_mask: Optional[paddle.Tensor]=None, + layer_head_mask: Optional[paddle.Tensor]=None, + output_attentions: bool=False, ) -> Tuple[paddle.Tensor, Optional[ + paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer @@ -455,7 +456,8 @@ class Wav2Vec2Attention(nn.Layer): key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) key_states = paddle.concat([past_key_value[0], key_states], axis=2) - value_states = paddle.concat([past_key_value[1], value_states], axis=2) + value_states = paddle.concat( + [past_key_value[1], value_states], axis=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) @@ -472,60 +474,68 @@ class Wav2Vec2Attention(nn.Layer): past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape) + query_states = self._shape(query_states, tgt_len, + bsz).reshape(proj_shape) key_states = key_states.reshape(proj_shape) value_states = value_states.reshape(proj_shape) src_len = key_states.shape[1] attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1])) - - + if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]: raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.shape}" - ) + f" {attn_weights.shape}") if attention_mask is not None: if attention_mask.shape != [bsz, 1, tgt_len, src_len]: raise ValueError( f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is {attention_mask.shape}" ) - attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, + src_len) + attention_mask + attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, + src_len) - attn_weights = nn.functional.softmax(attn_weights, axis=- 1) + attn_weights = nn.functional.softmax(attn_weights, axis=-1) if layer_head_mask is not None: - if layer_head_mask.shape != [self.num_heads,]: + if layer_head_mask.shape != [ + self.num_heads, + ]: raise ValueError( f"Head mask for a single layer should be of size {[self.num_heads,]}, but is" - f" {layer_head_mask.shape}" - ) - 
attn_weights = layer_head_mask.reshape((1, -1, 1, 1)) * attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len)) - attn_weights = attn_weights.reshape((bsz * self.num_heads, tgt_len, src_len)) + f" {layer_head_mask.shape}") + attn_weights = layer_head_mask.reshape( + (1, -1, 1, 1)) * attn_weights.reshape( + (bsz, self.num_heads, tgt_len, src_len)) + attn_weights = attn_weights.reshape( + (bsz * self.num_heads, tgt_len, src_len)) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len)) - attn_weights = attn_weights_reshaped.reshape((bsz * self.num_heads, tgt_len, src_len)) + attn_weights_reshaped = attn_weights.reshape( + (bsz, self.num_heads, tgt_len, src_len)) + attn_weights = attn_weights_reshaped.reshape( + (bsz * self.num_heads, tgt_len, src_len)) else: attn_weights_reshaped = None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training) attn_output = paddle.bmm(attn_probs, value_states) if attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]: raise ValueError( f"`attn_output` should be of size {[bsz, self.num_heads, tgt_len, self.head_dim]}, but is" - f" {attn_output.shape}" - ) + f" {attn_output.shape}") - attn_output = attn_output.reshape((bsz, self.num_heads, tgt_len, self.head_dim)) + attn_output = attn_output.reshape( + (bsz, self.num_heads, tgt_len, self.head_dim)) attn_output = attn_output.transpose([0, 2, 1, 3]) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be @@ -542,13 +552,15 @@ class Wav2Vec2FeedForward(nn.Layer): super().__init__() self.intermediate_dropout = nn.Dropout(config.activation_dropout) - self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_dense = nn.Linear(config.hidden_size, + config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act - self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dense = nn.Linear(config.intermediate_size, + config.hidden_size) self.output_dropout = nn.Dropout(config.hidden_dropout) def forward(self, hidden_states): @@ -568,18 +580,23 @@ class Wav2Vec2EncoderLayer(nn.Layer): embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, - is_decoder=False, - ) + is_decoder=False, ) self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) self.feed_forward = Wav2Vec2FeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) - def forward(self, hidden_states, attention_mask=None, output_attentions=False): + def forward(self, + hidden_states, + attention_mask=None, + output_attentions=False): attn_residual = hidden_states hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, 
output_attentions=output_attentions - ) + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states @@ -587,10 +604,10 @@ class Wav2Vec2EncoderLayer(nn.Layer): hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) return outputs @@ -602,27 +619,33 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Layer): embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, - is_decoder=False, - ) + is_decoder=False, ) self.dropout = nn.Dropout(config.hidden_dropout) - self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) self.feed_forward = Wav2Vec2FeedForward(config) - self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) - def forward(self, hidden_states, attention_mask=None, output_attentions=False): + def forward(self, + hidden_states, + attention_mask=None, + output_attentions=False): attn_residual = hidden_states hidden_states = self.layer_norm(hidden_states) hidden_states, attn_weights, _ = self.attention( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions) hidden_states = self.dropout(hidden_states) hidden_states = attn_residual + hidden_states - hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + hidden_states = hidden_states + self.feed_forward( + self.final_layer_norm(hidden_states)) - outputs = (hidden_states,) + outputs = (hidden_states, ) if output_attentions: - outputs += (attn_weights,) + outputs += (attn_weights, ) return outputs @@ -632,33 +655,38 @@ class Wav2Vec2Encoder(nn.Layer): super().__init__() self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) - self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) - self.layers = nn.LayerList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.LayerList([ + Wav2Vec2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) self.gradient_checkpointing = False def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if attention_mask is not None: # make sure padded tokens output 0 - expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + expand_attention_mask = attention_mask.unsqueeze(-1).repeat( + 1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + 
attention_mask = 1.0 - attention_mask[:, None, None, :].to( + dtype=hidden_states.dtype) attention_mask = attention_mask * np.iinfo(np.float32).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + attention_mask = attention_mask.expand(attention_mask.shape[0], 1, + attention_mask.shape[-1], + attention_mask.shape[-1]) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -669,13 +697,14 @@ class Wav2Vec2Encoder(nn.Layer): for layer in self.layers: if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) - skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer:# or deepspeed_zero3_is_enabled: + skip_the_layer = True if self.training and ( + dropout_probability < self.config.layerdrop) else False + if not skip_the_layer: # or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync if self.gradient_checkpointing and self.training: # create gradient checkpointing function @@ -686,26 +715,30 @@ class Wav2Vec2Encoder(nn.Layer): return custom_forward else: layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1], ) if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in + [hidden_states, all_hidden_states, all_self_attentions] + if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + attentions=all_self_attentions, ) class Wav2Vec2EncoderStableLayerNorm(nn.Layer): @@ -713,35 +746,39 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer): super().__init__() self.config = config self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) - self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.layer_norm = nn.LayerNorm( + config.hidden_size, epsilon=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) - self.layers = nn.LayerList( - [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] - ) + self.layers = nn.LayerList([ + Wav2Vec2EncoderLayerStableLayerNorm(config) + for _ in range(config.num_hidden_layers) + ]) self.gradient_checkpointing = False def forward( - self, - hidden_states, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if attention_mask is not None: # make sure padded tokens are not attended to 
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat_interleave(hidden_states.shape[2], axis=2) + expand_attention_mask = attention_mask.unsqueeze( + -1).repeat_interleave( + hidden_states.shape[2], axis=2) hidden_states[~expand_attention_mask] = 0 # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = 1.0 - attention_mask[:, None, None, :].to( + dtype=hidden_states.dtype) attention_mask = attention_mask * np.iinfo(np.float32).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + attention_mask = attention_mask.expand(attention_mask.shape[0], 1, + attention_mask.shape[-1], + attention_mask.shape[-1]) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -749,13 +786,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer): for layer in self.layers: if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) dropout_probability = np.random.uniform(0, 1) - skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False - if not skip_the_layer:# or deepspeed_zero3_is_enabled: + skip_the_layer = True if self.training and ( + dropout_probability < self.config.layerdrop) else False + if not skip_the_layer: # or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication if self.gradient_checkpointing and self.training: @@ -767,28 +805,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer): return custom_forward else: layer_outputs = layer( - hidden_states, attention_mask=attention_mask, output_attentions=output_attentions - ) + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions) hidden_states = layer_outputs[0] if skip_the_layer: layer_outputs = (None, None) if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1], ) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v + for v in + [hidden_states, all_hidden_states, all_self_attentions] + if v is not None) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) + attentions=all_self_attentions, ) class Wav2Vec2GumbelVectorQuantizer(nn.Layer): @@ -810,9 +852,13 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer): # storage for codebook variables (codewords) self.codevectors = paddle.static.create_parameter( - shape=[1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups], dtype='float32' - ) - self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + shape=[ + 1, self.num_groups * self.num_vars, + config.codevector_dim // self.num_groups + ], + dtype='float32') + self.weight_proj = nn.Linear(config.conv_dim[-1], + self.num_groups * self.num_vars) # can be decayed for training self.temperature = 2 @@ 
-826,7 +872,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer): else: marginal_probs = probs.mean(dim=0) - perplexity = paddle.exp(-paddle.sum(marginal_probs * paddle.log(marginal_probs + 1e-7), dim=-1)).sum() + perplexity = paddle.exp(-paddle.sum( + marginal_probs * paddle.log(marginal_probs + 1e-7), dim=-1)).sum() return perplexity def forward(self, hidden_states, mask_time_indices=None): @@ -834,35 +881,45 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer): # project to codevector dim hidden_states = self.weight_proj(hidden_states) - hidden_states = hidden_states.reshape((batch_size * sequence_length * self.num_groups, -1)) + hidden_states = hidden_states.reshape( + (batch_size * sequence_length * self.num_groups, -1)) if self.training: # sample code vector probs via gumbel in differentiateable way codevector_probs = nn.functional.gumbel_softmax( - hidden_states.float(), tau=self.temperature, hard=True - ).type_as(hidden_states) + hidden_states.float(), tau=self.temperature, + hard=True).type_as(hidden_states) # compute perplexity codevector_soft_dist = paddle.softmax( - hidden_states.reshape((batch_size * sequence_length, self.num_groups, -1)).float(), axis=-1 - ) - perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + hidden_states.reshape((batch_size * sequence_length, + self.num_groups, -1)).float(), + axis=-1) + perplexity = self._compute_perplexity(codevector_soft_dist, + mask_time_indices) else: # take argmax in non-differentiable way # comptute hard codevector distribution (one hot) codevector_idx = hidden_states.argmax(dim=-1) - codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_( - -1, codevector_idx.reshape((-1, 1)), 1.0 - ) - codevector_probs = codevector_probs.reshape((batch_size * sequence_length, self.num_groups, -1)) - - perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) - - codevector_probs = codevector_probs.reshape((batch_size * sequence_length, -1)) + codevector_probs = hidden_states.new_zeros( + *hidden_states.shape).scatter_(-1, + codevector_idx.reshape((-1, 1)), + 1.0) + codevector_probs = codevector_probs.reshape( + (batch_size * sequence_length, self.num_groups, -1)) + + perplexity = self._compute_perplexity(codevector_probs, + mask_time_indices) + + codevector_probs = codevector_probs.reshape( + (batch_size * sequence_length, -1)) # use probs to retrieve codevectors - codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors - codevectors = codevectors_per_group.reshape((batch_size * sequence_length, self.num_groups, self.num_vars, -1)) - codevectors = codevectors.sum(-2).reshape((batch_size, sequence_length, -1)) + codevectors_per_group = codevector_probs.unsqueeze( + -1) * self.codevectors + codevectors = codevectors_per_group.reshape( + (batch_size * sequence_length, self.num_groups, self.num_vars, -1)) + codevectors = codevectors.sum(-2).reshape( + (batch_size, sequence_length, -1)) return codevectors, perplexity @@ -878,7 +935,9 @@ class Wav2Vec2Adapter(nn.Layer): else: self.proj = self.proj_layer_norm = None - self.layers = nn.LayerList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layers = nn.LayerList( + Wav2Vec2AdapterLayer(config) + for _ in range(config.num_adapter_layers)) self.layerdrop = config.layerdrop def forward(self, hidden_states): @@ -906,8 +965,7 @@ class Wav2Vec2AdapterLayer(nn.Layer): 2 * config.output_hidden_size, config.adapter_kernel_size, stride=config.adapter_stride, - padding=1, - ) + padding=1, ) def 
forward(self, hidden_states): hidden_states = self.conv(hidden_states) @@ -916,7 +974,7 @@ class Wav2Vec2AdapterLayer(nn.Layer): return hidden_states -class Wav2Vec2Model(nn.Layer): +class Wav2Vec2Model(nn.Layer): def __init__(self, config): super().__init__() self.config = config @@ -925,9 +983,13 @@ class Wav2Vec2Model(nn.Layer): # model only needs masking vector if mask prob is > 0.0 if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: - # self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_()) + # self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_()) #self.masked_spec_embed = paddle.uniform([config.hidden_size]) - self.masked_spec_embed = paddle.static.create_parameter(shape=[config.hidden_size], dtype='float32', default_initializer=paddle.nn.initializer.Uniform(low=0, high=1.0)) + self.masked_spec_embed = paddle.static.create_parameter( + shape=[config.hidden_size], + dtype='float32', + default_initializer=paddle.nn.initializer.Uniform( + low=0, high=1.0)) if config.do_stable_layer_norm: self.encoder = Wav2Vec2EncoderStableLayerNorm(config) else: @@ -946,11 +1008,10 @@ class Wav2Vec2Model(nn.Layer): self.feature_extractor._freeze_parameters() def _mask_hidden_states( - self, - hidden_states: paddle.Tensor, - mask_time_indices: Optional[paddle.Tensor] = None, - attention_mask: Optional[paddle.Tensor] = None, - ): + self, + hidden_states: paddle.Tensor, + mask_time_indices: Optional[paddle.Tensor]=None, + attention_mask: Optional[paddle.Tensor]=None, ): """ Masks extracted features along time axis and/or along feature axis according to [SpecAugment](https://arxiv.org/abs/1904.08779). @@ -963,17 +1024,19 @@ class Wav2Vec2Model(nn.Layer): batch_size, sequence_length, hidden_size = hidden_states.shape if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices - hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + hidden_states[mask_time_indices] = self.masked_spec_embed.to( + hidden_states.dtype) elif self.config.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), mask_prob=self.config.mask_time_prob, mask_length=self.config.mask_time_length, attention_mask=attention_mask, - min_masks=self.config.mask_time_min_masks, - ) - mask_time_indices = paddle.to_tensor(mask_time_indices, dtype=paddle.bool) - hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + min_masks=self.config.mask_time_min_masks, ) + mask_time_indices = paddle.to_tensor( + mask_time_indices, dtype=paddle.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to( + hidden_states.dtype) if self.config.mask_feature_prob > 0 and self.training: # generate indices & apply SpecAugment along feature axis @@ -981,27 +1044,28 @@ class Wav2Vec2Model(nn.Layer): (batch_size, hidden_size), mask_prob=self.config.mask_feature_prob, mask_length=self.config.mask_feature_length, - min_masks=self.config.mask_feature_min_masks, - ) - mask_feature_indices = paddle.to_tensor(mask_feature_indices, dtype=paddle.bool) - mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + min_masks=self.config.mask_feature_min_masks, ) + mask_feature_indices = paddle.to_tensor( + mask_feature_indices, dtype=paddle.bool) + mask_feature_indices = mask_feature_indices[:, None].expand( + -1, sequence_length, -1) hidden_states[mask_feature_indices] = 0 return hidden_states def forward( - self, - 
input_values: Optional[paddle.Tensor], - attention_mask: Optional[paddle.Tensor] = None, - mask_time_indices: Optional[paddle.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_values: Optional[paddle.Tensor], + attention_mask: Optional[paddle.Tensor]=None, + mask_time_indices: Optional[paddle.Tensor]=None, + output_attentions: Optional[bool]=None, + output_hidden_states: Optional[bool]=None, + return_dict: Optional[bool]=None, ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict extract_features = self.feature_extractor(input_values) extract_features = extract_features.transpose([0, 2, 1]) @@ -1009,20 +1073,20 @@ class Wav2Vec2Model(nn.Layer): if attention_mask is not None: # compute reduced attention_mask corresponding to feature vectors attention_mask = self._get_feature_vector_attention_mask( - extract_features.shape[1], attention_mask, add_adapter=False - ) - hidden_states, extract_features = self.feature_projection(extract_features) + extract_features.shape[1], attention_mask, add_adapter=False) + hidden_states, extract_features = self.feature_projection( + extract_features) hidden_states = self._mask_hidden_states( - hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask - ) + hidden_states, + mask_time_indices=mask_time_indices, + attention_mask=attention_mask) encoder_outputs = self.encoder( hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) + return_dict=return_dict, ) hidden_states = encoder_outputs[0] @@ -1036,20 +1100,21 @@ class Wav2Vec2Model(nn.Layer): last_hidden_state=hidden_states, extract_features=extract_features, hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + attentions=encoder_outputs.attentions, ) def post_init(self): """ A method executed at the end of each Transformer model initialization, to execute code that needs the model's modules properly initialized (such as weight initialization). """ - # self.init_weights() - # self._backward_compatibility_gradient_checkpointing() + # self.init_weights() + # self._backward_compatibility_gradient_checkpointing() pass + class Wav2Vec2ConfigPure(): model_type = "wav2vec2" + def __init__(self, config): self.output_attentions = False self.output_hidden_states = False @@ -1084,17 +1149,14 @@ class Wav2Vec2ConfigPure(): self.do_stable_layer_norm = config.do_stable_layer_norm self.use_weighted_layer_sum = config.use_weighted_layer_sum - if ( - (len(self.conv_stride) != self.num_feat_extract_layers) - or (len(self.conv_kernel) != self.num_feat_extract_layers) - or (len(self.conv_dim) != self.num_feat_extract_layers) - ): + if ((len(self.conv_stride) != self.num_feat_extract_layers) or + (len(self.conv_kernel) != self.num_feat_extract_layers) or + (len(self.conv_dim) != self.num_feat_extract_layers)): raise ValueError( "Configuration for convolutional layers is incorrect. 
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
-                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
-            )
+                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`.")
 
         # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
         self.apply_spec_augment = config.apply_spec_augment
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py b/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
index 8eb9b4ad..9998a8e5 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/signal_processing.py
@@ -7,10 +7,8 @@ Authors
  * Samuele Cornell 2020
  * Sarthak Yadav 2022
 """
-import paddle
-import math
-from packaging import version
 import numpy as np
+import paddle
 
 
 def blackman_window(window_length, periodic=True):
@@ -90,15 +88,14 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
 
 
 def convolve1d(
-    waveform,
-    kernel,
-    padding=0,
-    pad_type="constant",
-    stride=1,
-    groups=1,
-    use_fft=False,
-    rotation_index=0,
-):
+        waveform,
+        kernel,
+        padding=0,
+        pad_type="constant",
+        stride=1,
+        groups=1,
+        use_fft=False,
+        rotation_index=0, ):
     """Use paddle.nn.functional to perform 1d padding and conv.
     Arguments
     ---------
@@ -150,8 +147,7 @@ def convolve1d(
     # Padding can be a tuple (left_pad, right_pad) or an int
     if isinstance(padding, tuple):
         waveform = paddle.nn.functional.pad(
-            x=waveform, pad=padding, mode=pad_type, data_format='NCL'
-        )
+            x=waveform, pad=padding, mode=pad_type, data_format='NCL')
 
     # This approach uses FFT, which is more efficient if the kernel is large
     if use_fft:
@@ -165,9 +161,7 @@ def convolve1d(
 
         # Perform rotation to ensure alignment
         zeros = paddle.zeros(
-            [kernel.shape[0], kernel.shape[1], zero_length],
-            dtype=kernel.dtype
-        )
+            [kernel.shape[0], kernel.shape[1], zero_length], dtype=kernel.dtype)
         after_index = kernel[..., rotation_index:]
         before_index = kernel[..., :rotation_index]
         kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
@@ -185,12 +179,12 @@ def convolve1d(
         weight=kernel,
         stride=stride,
         groups=groups,
-        padding=padding if not isinstance(padding, tuple) else 0,
-    )
+        padding=padding if not isinstance(padding, tuple) else 0, )
 
     # Return time dimension to the second dimension.
     return convolved.transpose([0, 2, 1])
 
+
 def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
     """Returns a notch filter constructed from a high-pass and low-pass filter.
     (from https://tomroelandts.com/articles/
@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
         return paddle.sin(x) / x
 
     # The zero is at the middle index
-    return paddle.concat([_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1 :])])
+    return paddle.concat(
+        [_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
 
     # Compute a low-pass filter with cutoff frequency notch_freq.
    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
 
     # Adding filters creates notch filter
     return (hlpf + hhpf).view(1, -1, 1)
-
diff --git a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
index f67121ed..471ab765 100644
--- a/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
+++ b/paddlespeech/s2t/models/wav2vec2/processing/speech_augmentation.py
@@ -1,11 +1,12 @@
 import math
+
 import paddle
 import paddle.nn as nn
-import paddle.nn.functional as F
-from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import (
-    compute_amplitude,
-    convolve1d,
-    notch_filter)
+
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import compute_amplitude
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
+from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter
+
 
 class SpeedPerturb(nn.Layer):
     """Slightly speed up or slow down an audio signal.
@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
     """
 
     def __init__(
-        self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0,
-    ):
+            self,
+            orig_freq,
+            speeds=[90, 100, 110],
+            perturb_prob=1.0, ):
         super().__init__()
         self.orig_freq = orig_freq
         self.speeds = speeds
@@ -70,14 +73,15 @@ class SpeedPerturb(nn.Layer):
 
         # Don't perturb (return early) 1-`perturb_prob` portion of the batches
         if paddle.rand([1]) > self.perturb_prob:
-
+
             return waveform.clone()
 
         # Perform a random perturbation
-        self.samp_index = paddle.randint(len(self.speeds), shape=(1,))[0]
+        self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
         perturbed_waveform = self.resamplers[self.samp_index](waveform)
 
         return perturbed_waveform
 
+
 class Resample(nn.Layer):
     """This class resamples an audio signal using sinc-based interpolation.
 
@@ -94,9 +98,12 @@ class Resample(nn.Layer):
         Controls the sharpness of the filter, larger numbers result in a
         sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
     """
+
     def __init__(
-        self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
-    ):
+            self,
+            orig_freq=16000,
+            new_freq=16000,
+            lowpass_filter_width=6, ):
         super().__init__()
         self.orig_freq = orig_freq
         self.new_freq = new_freq
@@ -193,8 +200,7 @@ class Resample(nn.Layer):
         window_size = self.weights.shape[1]
         tot_output_samp = self._output_samples(wave_len)
         resampled_waveform = paddle.zeros(
-            (batch_size, num_channels, tot_output_samp)
-        )
+            (batch_size, num_channels, tot_output_samp))
         # self.weights = self.weights.to(waveforms.device)
 
         # Check weights are on correct device
@@ -222,28 +228,25 @@ class Resample(nn.Layer):
             right_padding = max(0, end_index + 1 - current_wave_len)
             left_padding = max(0, -first_index)
             wave_to_conv = paddle.nn.functional.pad(
-                wave_to_conv, (left_padding, right_padding), data_format='NCL'
-            )
+                wave_to_conv, (left_padding, right_padding), data_format='NCL')
             conv_wave = paddle.nn.functional.conv1d(
                 x=wave_to_conv,
                 weight=self.weights[i].repeat(num_channels, 1, 1),
                 stride=self.conv_stride,
-                groups=num_channels,
-            )
+                groups=num_channels, )
 
             # we want conv_wave[:, i] to be at
             # output[:, i + n*conv_transpose_stride]
             dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
-                conv_wave, eye, stride=self.conv_transpose_stride
-            )
+                conv_wave, eye, stride=self.conv_transpose_stride)
 
             # pad dilated_conv_wave so it reaches the output length if needed.
            left_padding = i
             previous_padding = left_padding + dilated_conv_wave.shape[-1]
             right_padding = max(0, tot_output_samp - previous_padding)
             dilated_conv_wave = paddle.nn.functional.pad(
-                dilated_conv_wave, (left_padding, right_padding), data_format='NCL'
-            )
+                dilated_conv_wave, (left_padding, right_padding),
+                data_format='NCL')
             dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
 
             resampled_waveform += dilated_conv_wave
@@ -326,9 +329,7 @@ class Resample(nn.Layer):
         window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
         assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
 
-        output_t = paddle.arange(
-            start=0.0, end=self.output_samples
-        )
+        output_t = paddle.arange(start=0.0, end=self.output_samples)
         output_t /= self.new_freq
         min_t = output_t - window_width
         max_t = output_t + window_width
@@ -346,23 +347,16 @@ class Resample(nn.Layer):
         inside_window_indices = delta_t.abs() < (window_width)
 
         # raised-cosine (Hanning) window with width `window_width`
-        weights[inside_window_indices] = 0.5 * (
-            1
-            + paddle.cos(
-                2
-                * math.pi
-                * lowpass_cutoff
-                / self.lowpass_filter_width
-                * delta_t[inside_window_indices]
-            )
-        )
+        weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
+            2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
+            delta_t[inside_window_indices]))
 
         t_eq_zero_indices = delta_t == 0.0
         t_not_eq_zero_indices = ~t_eq_zero_indices
 
         # sinc filter function
         weights[t_not_eq_zero_indices] *= paddle.sin(
-            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
-        ) / (math.pi * delta_t[t_not_eq_zero_indices])
+            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
+                math.pi * delta_t[t_not_eq_zero_indices])
 
         # limit of the function at t = 0
         weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
@@ -405,14 +399,13 @@ class DropFreq(nn.Layer):
     """
 
     def __init__(
-        self,
-        drop_freq_low=1e-14,
-        drop_freq_high=1,
-        drop_count_low=1,
-        drop_count_high=2,
-        drop_width=0.05,
-        drop_prob=1,
-    ):
+            self,
+            drop_freq_low=1e-14,
+            drop_freq_high=1,
+            drop_count_low=1,
+            drop_count_high=2,
+            drop_width=0.05,
+            drop_prob=1, ):
         super().__init__()
         self.drop_freq_low = drop_freq_low
         self.drop_freq_high = drop_freq_high
@@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
 
         # Pick number of frequencies to drop
         drop_count = paddle.randint(
-            low=self.drop_count_low, high=self.drop_count_high + 1, shape=(1,),
-        )
+            low=self.drop_count_low,
+            high=self.drop_count_high + 1,
+            shape=(1, ), )
 
         # Pick a frequency to drop
         drop_range = self.drop_freq_high - self.drop_freq_low
         drop_frequency = (
-            paddle.rand(drop_count) * drop_range + self.drop_freq_low
-        )
+            paddle.rand(drop_count) * drop_range + self.drop_freq_low)
 
         # Filter parameters
         filter_length = 101
         pad = filter_length // 2
@@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
         # Subtract each frequency
         for frequency in drop_frequency:
             notch_kernel = notch_filter(
-                frequency, filter_length, self.drop_width,
-            )
+                frequency,
+                filter_length,
+                self.drop_width, )
             drop_filter = convolve1d(drop_filter, notch_kernel, pad)
 
         # Apply filter
@@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
         # Remove channels dimension if added
         return dropped_waveform.squeeze(-1)
 
+
 class DropChunk(nn.Layer):
     """This class drops portions of the input signal.
    Using `DropChunk` as an augmentation strategy helps a model learn to rely
@@ -515,16 +510,15 @@ class DropChunk(nn.Layer):
     """
 
     def __init__(
-        self,
-        drop_length_low=100,
-        drop_length_high=1000,
-        drop_count_low=1,
-        drop_count_high=10,
-        drop_start=0,
-        drop_end=None,
-        drop_prob=1,
-        noise_factor=0.0,
-    ):
+            self,
+            drop_length_low=100,
+            drop_length_high=1000,
+            drop_count_low=1,
+            drop_count_high=10,
+            drop_start=0,
+            drop_end=None,
+            drop_prob=1,
+            noise_factor=0.0, ):
         super().__init__()
         self.drop_length_low = drop_length_low
         self.drop_length_high = drop_length_high
@@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
         drop_times = paddle.randint(
             low=self.drop_count_low,
             high=self.drop_count_high + 1,
-            shape=(batch_size,),
-        )
+            shape=(batch_size, ), )
 
         # Iterate batch to set mask
         for i in range(batch_size):
@@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
             length = paddle.randint(
                 low=self.drop_length_low,
                 high=self.drop_length_high + 1,
-                shape=(drop_times[i],),
-            )
+                shape=(drop_times[i], ), )
 
             # Compute range of starting locations
             start_min = self.drop_start
@@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
 
             # Pick starting locations
             start = paddle.randint(
-                low=start_min, high=start_max + 1, shape=(drop_times[i],),
-            )
+                low=start_min,
+                high=start_max + 1,
+                shape=(drop_times[i], ), )
 
             end = start + length
 
             # Update waveform
             if not self.noise_factor:
                 for j in range(drop_times[i]):
-                    dropped_waveform[i, start[j] : end[j]] = 0.0
+                    dropped_waveform[i, start[j]:end[j]] = 0.0
             else:
                 # Uniform distribution of -2 to +2 * avg amplitude should
                 # preserve the average for normalization
@@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
                     # zero-center the noise distribution
                     noise_vec = paddle.rand([length[j]])
                     noise_vec = 2 * noise_max * noise_vec - noise_max
-                    dropped_waveform[i, start[j] : end[j]] = noise_vec
+                    dropped_waveform[i, start[j]:end[j]] = noise_vec
 
         return dropped_waveform
 
@@ -679,37 +672,33 @@ class TimeDomainSpecAugment(nn.Layer):
     """
 
     def __init__(
-        self,
-        perturb_prob=1.0,
-        drop_freq_prob=1.0,
-        drop_chunk_prob=1.0,
-        speeds=[95, 100, 105],
-        sample_rate=16000,
-        drop_freq_count_low=0,
-        drop_freq_count_high=3,
-        drop_chunk_count_low=0,
-        drop_chunk_count_high=5,
-        drop_chunk_length_low=1000,
-        drop_chunk_length_high=2000,
-        drop_chunk_noise_factor=0,
-    ):
+            self,
+            perturb_prob=1.0,
+            drop_freq_prob=1.0,
+            drop_chunk_prob=1.0,
+            speeds=[95, 100, 105],
+            sample_rate=16000,
+            drop_freq_count_low=0,
+            drop_freq_count_high=3,
+            drop_chunk_count_low=0,
+            drop_chunk_count_high=5,
+            drop_chunk_length_low=1000,
+            drop_chunk_length_high=2000,
+            drop_chunk_noise_factor=0, ):
         super().__init__()
         self.speed_perturb = SpeedPerturb(
-            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
-        )
+            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
         self.drop_freq = DropFreq(
             drop_prob=drop_freq_prob,
             drop_count_low=drop_freq_count_low,
-            drop_count_high=drop_freq_count_high,
-        )
+            drop_count_high=drop_freq_count_high, )
         self.drop_chunk = DropChunk(
             drop_prob=drop_chunk_prob,
             drop_count_low=drop_chunk_count_low,
             drop_count_high=drop_chunk_count_high,
             drop_length_low=drop_chunk_length_low,
             drop_length_high=drop_chunk_length_high,
-            noise_factor=drop_chunk_noise_factor,
-        )
+            noise_factor=drop_chunk_noise_factor, )
 
     def forward(self, waveforms, lengths):
         """Returns the distorted waveforms.
@@ -724,4 +713,4 @@ class TimeDomainSpecAugment(nn.Layer):
         waveforms = self.speed_perturb(waveforms)
         waveforms = self.drop_freq(waveforms)
         waveforms = self.drop_chunk(waveforms, lengths)
-        return waveforms
\ No newline at end of file
+        return waveforms
diff --git a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
index 6c8b0ee4..f54748f8 100644
--- a/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
+++ b/paddlespeech/s2t/models/wav2vec2/wav2vec2_ASR.py
@@ -1,30 +1,24 @@
-import numpy as np
-import os
-
+from collections import defaultdict
 from typing import Dict
 from typing import List
-from typing import Optional
 from typing import Tuple
 
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+
 from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
 from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
-from paddlespeech.s2t.modules.mask import make_pad_mask
-from paddlespeech.s2t.utils.utility import log_add
-
-from collections import defaultdict
-
 from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
 from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
 from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
-from yacs.config import CfgNode
+from paddlespeech.s2t.utils.utility import log_add
+
 
 class Wav2vec2ASR(nn.Layer):
     def __init__(self, config: dict):
         super().__init__()
-
+
         wav2vec2_config = Wav2Vec2ConfigPure(config)
         wav2vec2 = Wav2Vec2Model(wav2vec2_config)
         model_dict = paddle.load(config.wav2vec2_params_path)
@@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
         for parm in wav2vec2.parameters():
             parm.trainable = False
         self.wav2vec2 = wav2vec2
-        self.enc = VanillaNN(input_shape=[None,None,wav2vec2_config.hidden_size], activation=nn.LeakyReLU, dnn_blocks=config.dnn_blocks, dnn_neurons=config.dnn_neurons)
-        self.ctc = CTC(odim=config.output_dim, enc_n_units=config.dnn_neurons, blank_id=config.blank_id, dropout_rate=config.ctc_dropout_rate, reduction=True)
+        self.enc = VanillaNN(
+            input_shape=[None, None, wav2vec2_config.hidden_size],
+            activation=nn.LeakyReLU,
+            dnn_blocks=config.dnn_blocks,
+            dnn_neurons=config.dnn_neurons)
+        self.ctc = CTC(odim=config.output_dim,
+                       enc_n_units=config.dnn_neurons,
+                       blank_id=config.blank_id,
+                       dropout_rate=config.ctc_dropout_rate,
+                       reduction=True)
 
     def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
         if self.normalize_wav:
@@ -51,25 +53,27 @@ class Wav2vec2ASR(nn.Layer):
         x = self.enc(feats)
         x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
-        target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64)
-
+        target_lens = (target_lens_rate *
+                       target.shape[1]).round().astype(paddle.int64)
+
         ctc_loss = self.ctc(x, x_lens, target, target_lens)
         return ctc_loss
 
     @paddle.no_grad()
-    def decode(self,
+    def decode(self,
                feats: paddle.Tensor,
                text_feature: Dict[str, int],
                decoding_method: str,
                beam_size: int):
         batch_size = feats.shape[0]
-        if decoding_method is 'ctc_prefix_beam_search' and batch_size > 1:
+
+        if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
             logger.error(
                 f'decoding mode {decoding_method} must be running with batch_size == 1'
             )
             logger.error(f"current batch_size is {batch_size}")
             sys.exit(1)
-
+
         if decoding_method == 'ctc_greedy_search':
             hyps = self.ctc_greedy_search(feats)
             res = [text_feature.defeaturize(hyp) for hyp in hyps]
@@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
         # with other batch decoding mode
        elif decoding_method == 'ctc_prefix_beam_search':
             assert feats.shape[0] == 1
-            hyp = self.ctc_prefix_beam_search(
-                feats,
-                beam_size)
+            hyp = self.ctc_prefix_beam_search(feats, beam_size)
             res = [text_feature.defeaturize(hyp)]
             res_tokenids = [hyp]
         else:
-            raise ValueError(f"wav2vec2 not support decoding method: {decoding_method}")
+            raise ValueError(
+                f"wav2vec2 does not support decoding method: {decoding_method}")
 
         return res, res_tokenids
 
@@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
         model = cls(config)
         return model
 
-    def ctc_greedy_search(
-        self, wav) -> List[List[int]]:
+    def ctc_greedy_search(self, wav) -> List[List[int]]:
         """ Apply CTC greedy search
         Args:
             speech (paddle.Tensor): (batch, max_len)
@@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
         Returns:
             List[List[int]]: best path result
         """
         batch_size = wav.shape[0]
-        wav = wav[:,:,0]
+        wav = wav[:, :, 0]
         if self.normalize_wav:
             wav = F.layer_norm(wav, wav.shape[1:])
         # Extract wav2vec output
@@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
         return hyps
 
     def _ctc_prefix_beam_search(
-        self, wav, beam_size, blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
+            self,
+            wav,
+            beam_size,
+            blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
         """ CTC prefix beam search inner implementation
         Args:
             speech (paddle.Tensor): (batch, max_len, feat_dim)
@@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
             paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                 it will be used for rescoring in attention rescoring mode
         """
-        wav = wav[:,:,0]
+        wav = wav[:, :, 0]
 
         if self.normalize_wav:
             wav = F.layer_norm(wav, wav.shape[1:])
@@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
         Returns:
             List[int]: CTC prefix beam search nbest results
         """
-        hyps = self._ctc_prefix_beam_search(
-            wav, beam_size)
+        hyps = self._ctc_prefix_beam_search(wav, beam_size)
         return hyps[0][0]
-
-    # @jit.to_static
-    # def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
-    #     """ Export interface for c++ call, apply linear transform and log
-    #         softmax before ctc
-    #     Args:
-    #         xs (paddle.Tensor): encoder output, (B, T, D)
-    #     Returns:
-    #         paddle.Tensor: activation before ctc
-    #     """
-    #     return self.ctc.log_softmax(xs)
-
-    # def _get_data(self):
-    #     data_dir = "data"
-    #     wavs = np.load(os.path.join(data_dir, "wavs.npy"))
-    #     wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
-    #     tokens = np.load(os.path.join(data_dir, "tokens.npy"))
-    #     tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
-
-    #     batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
-    #              paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
-    #     return batch
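Note on the masking added in modeling_wav2vec2.py: `_mask_hidden_states` replaces masked time steps with the learned `masked_spec_embed` vector created via `paddle.static.create_parameter` above. A minimal numpy sketch of just the indexing semantics (the real span sampler, `_compute_mask_indices`, lives elsewhere in that file; the spans and shapes below are hand-picked assumptions, not values from the PR):

    import numpy as np

    batch, frames, hidden = 2, 10, 4
    hidden_states = np.zeros((batch, frames, hidden), dtype=np.float32)
    # stands in for the learned masked_spec_embed parameter
    masked_spec_embed = np.random.uniform(0.0, 1.0, size=hidden).astype(np.float32)

    # boolean (batch, frames) mask standing in for _compute_mask_indices output
    mask_time_indices = np.zeros((batch, frames), dtype=bool)
    mask_time_indices[0, 2:5] = True
    mask_time_indices[1, 6:8] = True

    # boolean indexing selects (n_masked, hidden) rows; the embed vector
    # broadcasts across them, exactly as in _mask_hidden_states
    hidden_states[mask_time_indices] = masked_spec_embed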
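The `TimeDomainSpecAugment` module added in speech_augmentation.py chains speed perturbation, frequency dropping, and chunk dropping on raw waveforms, and the trainer applies it before the frozen wav2vec2 encoder. A minimal usage sketch, assuming a 16 kHz batch of full-length utterances; the tensor shapes and values are illustrative only:

    import paddle

    from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment

    # (batch, time): four 2-second utterances at 16 kHz
    wav = paddle.randn([4, 32000])
    # relative lengths in [0, 1]; 1.0 means the utterance fills its row
    lengths = paddle.ones([4])

    augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])
    # speed_perturb -> drop_freq -> drop_chunk, as in forward() above;
    # note the time dimension can change when a 95%/105% speed is drawn
    augmented = augment(wav, lengths)
    print(augmented.shape)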
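`ctc_greedy_search` in wav2vec2_ASR.py follows the standard CTC best-path rule: take the argmax token per frame, collapse consecutive repeats, then drop blanks (the step `remove_duplicates_and_blank` performs). A self-contained toy sketch of that rule; the helper name and the frame sequence are made up for illustration:

    def collapse_ctc_path(frame_ids, blank_id=0):
        """Collapse a per-frame argmax path into a label sequence."""
        out, prev = [], None
        for tok in frame_ids:
            # emit only on a change of symbol, and never emit the blank
            if tok != prev and tok != blank_id:
                out.append(tok)
            prev = tok
        return out

    # frames "a a - b b - b" (blank = 0) decode to [a, b, b]:
    assert collapse_ctc_path([1, 1, 0, 2, 2, 0, 2], blank_id=0) == [1, 2, 2]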