format wav2vec2 demo

pull/2518/head
tianhao zhang
parent 6e429f0513
commit 19180d359d

@@ -33,7 +33,7 @@ filename =
 # Specify a list of codes to ignore.
 ignore =
     W503
-    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
+    E252,E262,E127,E265,E126,E266,E241,E261,E128,E125,E129
     W291,W293,W605
     E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
 # shebang has extra meaning in fbcode lints, so I think it's not worth trying

@@ -3,7 +3,7 @@
 * asr0 - deepspeech2 Streaming/Non-Streaming
 * asr1 - transformer/conformer Streaming/Non-Streaming
 * asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
+* asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC

 ## Data
 | Data Subset | Duration in Seconds |

@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
         return mat

+class WavProcess():
+    def __init__(self, dither=0.1):
+        """
+        Args:
+            dither (float): Dithering constant
+        Returns:
+        """
+        self.dither = dither
+
+    def __call__(self, x, train):
+        """
+        Args:
+            x (np.ndarray): shape (Ti,)
+            train (bool): True, train mode.
+        Raises:
+            ValueError: not support (Ti, C)
+        Returns:
+            np.ndarray: (T, D)
+        """
+        dither = self.dither if train else 0.0
+        if x.ndim != 1:
+            raise ValueError("Not support x: [Time, Channel]")
+        waveform = np.expand_dims(x, -1)
+        return waveform
+
 class LogMelSpectrogramKaldi_decay():
     def __init__(
             self,

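The `WavProcess` transform added above is a thin pass-through: it validates that the input is a mono waveform of shape (Ti,) and appends a trailing channel axis, giving (Ti, 1); note that the selected dither value is not actually applied to the samples in this version. A minimal standalone sketch of that behavior (class body inlined here rather than imported):

```python
import numpy as np

class WavProcess:
    """Sketch of the transform above: validate a mono waveform and append
    a channel axis. The dither constant is chosen per mode but not applied
    to the samples in the version shown in this diff."""

    def __init__(self, dither=0.1):
        self.dither = dither

    def __call__(self, x, train):
        dither = self.dither if train else 0.0  # currently unused below
        if x.ndim != 1:
            raise ValueError("Not support x: [Time, Channel]")
        return np.expand_dims(x, -1)

wav = np.random.randn(16000).astype(np.float32)  # 1 s of 16 kHz audio
print(WavProcess()(wav, train=True).shape)  # (16000, 1)
```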
@@ -41,6 +41,7 @@ import_alias = dict(
     utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
     fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
     spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
+    wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
     stft="paddlespeech.audio.transform.spectrogram:Stft",
     istft="paddlespeech.audio.transform.spectrogram:IStft",
     stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",

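Each entry in `import_alias` is a `"package.module:ClassName"` spec, so the new `wav_process` key makes the transform loadable by name from data-pipeline configs. A hedged sketch of how such a spec can be resolved at runtime (the helper name here is illustrative; paddlespeech's own loader may differ):

```python
import importlib

def resolve_alias(spec: str):
    """Resolve a "package.module:ClassName" spec to the class object."""
    module_name, _, class_name = spec.partition(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Hypothetical usage mirroring the registry entry added above
# (requires paddlespeech to be installed):
# WavProcess = resolve_alias(
#     "paddlespeech.audio.transform.spectrogram:WavProcess")
```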
@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
 from paddlespeech.s2t.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()

 class Wav2vec2Infer():
     def __init__(self, config, args):
         self.args = args
@@ -34,8 +35,7 @@ class Wav2vec2Infer():
         self.audio_file = args.audio_file
         self.text_feature = TextFeaturizer(
-            unit_type=config.unit_type,
-            vocab=config.vocab_filepath)
+            unit_type=config.unit_type, vocab=config.vocab_filepath)

         paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
         # model

@@ -18,38 +18,38 @@ import time
 from collections import defaultdict
 from collections import OrderedDict
 from contextlib import nullcontext
-from paddlespeech.s2t.utils import mp_tools
 import jsonlines
 import numpy as np
 import paddle
 from paddle import distributed as dist
 from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
 from paddlespeech.s2t.io.dataloader import BatchDataLoader
-from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.io.dataloader import DataLoaderFactory
-from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
+from paddlespeech.s2t.io.dataloader import StreamDataLoader
 from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
-from paddlespeech.s2t.utils import error_rate
+from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
 from paddlespeech.s2t.training.optimizer import OptimizerFactory
 from paddlespeech.s2t.training.reporter import ObsScope
 from paddlespeech.s2t.training.reporter import report
 from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
 from paddlespeech.s2t.training.timer import Timer
 from paddlespeech.s2t.training.trainer import Trainer
-from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.s2t.utils import error_rate
 from paddlespeech.s2t.utils import layer_tools
+from paddlespeech.s2t.utils import mp_tools
 from paddlespeech.s2t.utils.log import Log
+from paddlespeech.s2t.utils.utility import UpdateConfig

 logger = Log(__name__).getlog()

 class Wav2Vec2ASRTrainer(Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
         self.avg_train_loss = 0

     def train_batch(self, batch_index, batch, msg):
         train_conf = self.config
         start = time.time()
@@ -58,7 +58,7 @@ class Wav2Vec2ASRTrainer(Trainer):
         utt, wav, wavs_lens, target, target_lens = batch
         wavs_lens_rate = wavs_lens / wav.shape[1]
         target_lens_rate = target_lens / target.shape[1]
-        wav = wav[:,:,0]
+        wav = wav[:, :, 0]
         wav = self.speech_augmentation(wav, wavs_lens_rate)
         loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
         # pring(wav, wavs_lens_rate, target, target_lens_rate)
@@ -108,7 +108,8 @@ class Wav2Vec2ASRTrainer(Trainer):
     def valid(self):
         self.model.eval()
         if not self.use_streamdata:
-            logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
+            logger.info(
+                f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
@@ -116,7 +117,7 @@ class Wav2Vec2ASRTrainer(Trainer):
             utt, wav, wavs_lens, target, target_lens = batch
             wavs_lens_rate = wavs_lens / wav.shape[1]
             target_lens_rate = target_lens / target.shape[1]
-            wav = wav[:,:,0]
+            wav = wav[:, :, 0]
             loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
             if paddle.isfinite(loss):
@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
                 msg += "epoch: {}, ".format(self.epoch)
                 msg += "step: {}, ".format(self.iteration)
                 if not self.use_streamdata:
-                    msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
+                    msg += "batch: {}/{}, ".format(i + 1,
+                                                   len(self.valid_loader))
                 msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in valid_dump.items())
                 logger.info(msg)
@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
         self.before_train()
         if not self.use_streamdata:
-            logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
+            logger.info(
+                f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.n_epoch:
             with Timer("Epoch-Train Time Cost: {}"):
                 self.model.train()
@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
         config = self.config.clone()
         self.use_streamdata = config.get("use_stream_data", False)
         if self.train:
-            self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
-            self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
+            self.train_loader = DataLoaderFactory.get_dataloader(
+                'train', config, self.args)
+            self.valid_loader = DataLoaderFactory.get_dataloader(
+                'valid', config, self.args)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
                 'decode_batch_size', 1)
-            self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
-            self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
+            self.test_loader = DataLoaderFactory.get_dataloader('test', config,
                                                                 self.args)
+            self.align_loader = DataLoaderFactory.get_dataloader(
+                'align', config, self.args)
             logger.info("Setup test/align Dataloader!")

     def setup_model(self):
@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
         self.text_featurizer = TextFeaturizer(
             unit_type=config.unit_type, vocab=config.vocab_filepath)
         self.vocab_list = self.text_featurizer.vocab_list

     def id2token(self, texts, texts_len):
         """ ord() id to chr() chr """
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(
-                self.text_featurizer.defeaturize(ids.numpy().tolist()))
+            trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
         return trans

     def compute_metrics(self,

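`id2token` above trims each padded id sequence to its true length before `defeaturize` maps the ids back to text. A self-contained sketch of the same logic, with a toy vocabulary standing in for `TextFeaturizer.defeaturize` (the names here are illustrative, not the paddlespeech API):

```python
import numpy as np

# Toy stand-in for TextFeaturizer.defeaturize: map ids back to characters.
vocab = {1: "h", 2: "i", 3: "t", 4: "o", 5: "m"}

def defeaturize(ids):
    return "".join(vocab[i] for i in ids)

def id2token(texts, texts_len):
    """Trim each row of a padded id matrix to its length, then decode."""
    return [defeaturize(text[:n].tolist())
            for text, n in zip(texts, texts_len)]

texts = np.array([[1, 2, 0, 0], [3, 4, 5, 0]])  # rows padded with 0
texts_len = np.array([2, 3])
print(id2token(texts, texts_len))  # ['hi', 'tom']
```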
@@ -3,6 +3,7 @@ Authors
 * Elena Rastorgueva 2020
 """
 import paddle

 from paddlespeech.s2t.models.wav2vec2.modules import containers
 from paddlespeech.s2t.models.wav2vec2.modules import linear
@@ -31,8 +32,7 @@ class VanillaNN(containers.Sequential):
             input_shape,
             activation=paddle.nn.LeakyReLU,
             dnn_blocks=2,
-            dnn_neurons=512,
-    ):
+            dnn_neurons=512, ):
         super().__init__(input_shape=input_shape)

         for block_index in range(dnn_blocks):
@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
                 linear.Linear,
                 n_neurons=dnn_neurons,
                 bias=True,
-                layer_name="linear",
-            )
+                layer_name="linear", )
             self.append(activation(), layer_name="act")

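Formatting aside, `VanillaNN` is just `dnn_blocks` repetitions of Linear + LeakyReLU applied on top of the wav2vec2 features. A hedged sketch of an equivalent stack in plain paddle, ignoring the shape-inference machinery of `containers.Sequential` (the 768 below assumes a wav2vec2-base hidden size):

```python
import paddle
import paddle.nn as nn

def vanilla_nn(input_size: int, dnn_blocks: int = 2, dnn_neurons: int = 512):
    """Plain-paddle equivalent of the VanillaNN head above."""
    layers, in_features = [], input_size
    for _ in range(dnn_blocks):
        layers.append(nn.Linear(in_features, dnn_neurons))
        layers.append(nn.LeakyReLU())
        in_features = dnn_neurons
    return nn.Sequential(*layers)

head = vanilla_nn(input_size=768)
x = paddle.randn([4, 100, 768])  # (batch, time, features)
print(head(x).shape)  # [4, 100, 512]
```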
@@ -11,12 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math

-from packaging import version
-from paddle import Tensor, nn
+from paddle import nn
+from paddle import Tensor

 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()
@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
     """

     def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
+        return 0.5 * input * (1.0 + paddle.tanh(
+            math.sqrt(2.0 / math.pi) *
+            (input + 0.044715 * paddle.pow(input, 3.0))))

 class GELUActivation(nn.Layer):
@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
     Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """

-    def __init__(self, use_gelu_python: bool = False):
+    def __init__(self, use_gelu_python: bool=False):
         super().__init__()
         self.act = nn.functional.gelu
@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
     """

     def forward(self, input: Tensor) -> Tensor:
-        return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+        return 0.5 * input * (
+            1.0 + paddle.tanh(input * 0.7978845608 *
+                              (1.0 + 0.044715 * input * input)))

 class QuickGELUActivation(nn.Layer):
@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
     def __init__(self, min: float, max: float):
         if min > max:
-            raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+            raise ValueError(
+                f"min should be < max (got min: {min}, max: {max})")

         super().__init__()
         self.min = min
@@ -161,7 +164,9 @@ def get_activation(activation_string):
     if activation_string in ACT2FN:
         return ACT2FN[activation_string]
     else:
-        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+        raise KeyError(
+            f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
+        )

 # For backwards compatibility with: from activations import gelu_python

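The reflowed `forward` bodies above are all approximations of the same function, GELU(x) = x·Φ(x) with Φ the standard normal CDF; the constant 0.7978845608 in FastGELU is sqrt(2/π), so it is algebraically the same polynomial as NewGELU. A quick numpy check of how close the tanh variants stay to the exact erf form (illustrative only, not part of the commit):

```python
import math
import numpy as np

x = np.linspace(-4.0, 4.0, 801)

# Exact GELU: x * Phi(x), with Phi the standard normal CDF via erf.
exact = x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

# NewGELUActivation: the tanh approximation used by GPT-2/BERT.
new_gelu = 0.5 * x * (1.0 + np.tanh(
    math.sqrt(2.0 / math.pi) * (x + 0.044715 * np.power(x, 3.0))))

# FastGELUActivation: the same polynomial, factored differently.
fast_gelu = 0.5 * x * (1.0 + np.tanh(
    x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

print(np.max(np.abs(new_gelu - exact)))   # on the order of 1e-3 or less
print(np.max(np.abs(fast_gelu - new_gelu)))  # essentially identical
```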
@@ -1,8 +1,7 @@
-import paddle
 import inspect
-import logging
-import operator
-import functools
+import paddle

 class Sequential(paddle.nn.LayerDict):
     """A sequence of modules with potentially inferring shape on construction.
@@ -103,8 +102,7 @@ class Sequential(paddle.nn.LayerDict):
             raise ValueError(
                 "Must pass `input_shape` at initialization and use "
                 "modules that take `input_shape` to infer shape when "
-                "using `append()`."
-            )
+                "using `append()`.")

     def get_output_shape(self):
         """Returns expected shape of the output.

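The error message reformatted above fires when `append()` must infer a layer's input size but the container was never given an `input_shape`. The inference idea itself is simple: push a dummy tensor through the layers built so far and read the output shape off the result. A toy sketch of that mechanism (not the paddlespeech implementation):

```python
import paddle
import paddle.nn as nn

class ShapeInferSequential(nn.LayerList):
    """Toy container: infer each Linear's input size from a dummy forward."""

    def __init__(self, input_shape):
        super().__init__()
        self.input_shape = list(input_shape)

    def get_output_shape(self):
        dummy = paddle.zeros(self.input_shape)
        for layer in self:
            dummy = layer(dummy)
        return dummy.shape

    def append_linear(self, n_neurons):
        in_features = self.get_output_shape()[-1]
        super().append(nn.Linear(in_features, n_neurons))

model = ShapeInferSequential(input_shape=[1, 10, 768])
model.append_linear(512)
model.append_linear(256)  # input size 512 is inferred automatically
print(model.get_output_shape())  # [1, 10, 256]
```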
@@ -3,10 +3,10 @@ Authors
 * Mirco Ravanelli 2020
 * Davide Borra 2021
 """
 import logging

 import paddle
-import paddle.nn as nn

 from paddlespeech.s2t.modules import align

 logger = logging.getLogger(__name__)
@@ -42,8 +42,7 @@ class Linear(paddle.nn.Layer):
             input_shape=None,
             input_size=None,
             bias=True,
-            combine_dims=False,
-    ):
+            combine_dims=False, ):
         super().__init__()
         self.combine_dims = combine_dims

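The `combine_dims` flag in the signature reformatted above lets this Linear wrapper accept 4-D inputs by merging the two trailing feature dimensions before the matmul. A hedged sketch of that reshape in isolation (illustrative, not the exact paddlespeech code path):

```python
import paddle

x = paddle.randn([8, 100, 40, 2])  # (batch, time, channel1, channel2)

# What combine_dims=True does conceptually: flatten the trailing feature
# dims so a standard nn.Linear can consume the tensor.
if x.ndim == 4:
    x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]])

print(x.shape)  # [8, 100, 80]
```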
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
-from typing import Optional, Tuple
 from collections import OrderedDict
+from dataclasses import dataclass
 from dataclasses import fields
+from typing import Optional
+from typing import Tuple

 import paddle
@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict):
         if not len(class_fields):
             raise ValueError(f"{self.__class__.__name__} has no fields.")
         if not all(field.default is None for field in class_fields[1:]):
-            raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
+            raise ValueError(
+                f"{self.__class__.__name__} should not have more than one required field."
+            )

         first_field = getattr(self, class_fields[0].name)
-        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+        other_fields_are_none = all(
+            getattr(self, field.name) is None for field in class_fields[1:])

         if other_fields_are_none and not paddle.is_tensor(first_field):
             if isinstance(first_field, dict):
@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict):
             # set the associated fields
             if first_field_iterator:
                 for element in iterator:
-                    if (
-                        not isinstance(element, (list, tuple))
-                        or not len(element) == 2
-                        or not isinstance(element[0], str)
-                    ):
+                    if (not isinstance(element, (list, tuple)) or
+                            not len(element) == 2 or
+                            not isinstance(element[0], str)):
                         break
                     setattr(self, element[0], element[1])
                     if element[1] is not None:
@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict):
                 self[field.name] = v

     def __delitem__(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+        raise Exception(
+            f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
+        )

     def setdefault(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+        raise Exception(
+            f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
+        )

     def pop(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+        raise Exception(
+            f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

     def update(self, *args, **kwargs):
-        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+        raise Exception(
+            f"You cannot use ``update`` on a {self.__class__.__name__} instance."
+        )

     def __getitem__(self, k):
         if isinstance(k, str):

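`ModelOutput` above is an `OrderedDict` backed by a dataclass: fields are readable by attribute or by key, fields left as `None` are dropped from the mapping, and in-place mutation is deliberately blocked, which is why `pop`, `update`, `setdefault`, and `__delitem__` all raise. A small sketch of the intended usage with a hypothetical subclass (the real file defines `BaseModelOutput` and friends this way):

```python
from dataclasses import dataclass
from typing import Optional

import paddle
from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ModelOutput

# Hypothetical subclass, just to illustrate the contract.
@dataclass
class ToyOutput(ModelOutput):
    last_hidden_state: paddle.Tensor = None
    attentions: Optional[tuple] = None

out = ToyOutput(last_hidden_state=paddle.ones([2, 3]))
print(out.last_hidden_state.shape)     # attribute access: [2, 3]
print(out["last_hidden_state"].shape)  # key access works too
print(list(out.keys()))                # None fields are dropped
# out.pop("attentions")  # would raise: mutation is blocked by design
```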
@@ -13,24 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Paddle Wav2Vec2 model."""
-import math
-import warnings
-import paddle
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Optional
+from typing import Tuple
+from typing import Union

 import numpy as np
+import paddle
 from paddle import nn

 from paddlespeech.s2t.models.wav2vec2.modules.activations import ACT2FN
-from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import (
-    BaseModelOutput,
-    Wav2Vec2BaseModelOutput,
-    ModelOutput
-)
-import pdb
+from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import BaseModelOutput
+from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ModelOutput
+from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import Wav2Vec2BaseModelOutput
 from paddlespeech.s2t.utils.log import Log

 logger = Log(__name__).getlog()
@@ -81,9 +76,8 @@ def _compute_mask_indices(
         shape: Tuple[int, int],
         mask_prob: float,
         mask_length: int,
-        attention_mask: Optional[paddle.Tensor] = None,
-        min_masks: int = 0,
-) -> np.ndarray:
+        attention_mask: Optional[paddle.Tensor]=None,
+        min_masks: int=0, ) -> np.ndarray:
     """
     Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
     ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
@@ -109,8 +103,7 @@ def _compute_mask_indices(
     if mask_length > sequence_length:
         raise ValueError(
             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
-            f" and `sequence_length`: {sequence_length}`"
-        )
+            f" and `sequence_length`: {sequence_length}`")

     # epsilon is used for probabilistic rounding
     epsilon = np.random.rand(1).item()
@@ -131,11 +124,9 @@ def _compute_mask_indices(
         return num_masked_span

     # compute number of masked spans in batch
-    input_lengths = (
-        attention_mask.sum(-1).detach().tolist()
-        if attention_mask is not None
-        else [sequence_length for _ in range(batch_size)]
-    )
+    input_lengths = (attention_mask.sum(-1).detach().tolist()
                     if attention_mask is not None else
                     [sequence_length for _ in range(batch_size)])

     # SpecAugment mask to fill
     spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=np.bool)
@@ -152,8 +143,9 @@ def _compute_mask_indices(
         # get random indices to mask
         spec_aug_mask_idx = np.random.choice(
-            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
-        )
+            np.arange(input_length - (mask_length - 1)),
+            num_masked_span,
+            replace=False)

         # pick first sampled index that will serve as a dummy index to pad vector
         # to ensure same dimension for all batches due to probabilistic rounding
@@ -166,29 +158,33 @@ def _compute_mask_indices(
         else:
             dummy_mask_idx = spec_aug_mask_idx[0]

-        spec_aug_mask_idx = np.concatenate(
-            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
-        )
+        spec_aug_mask_idx = np.concatenate([
+            spec_aug_mask_idx,
+            np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) *
+            dummy_mask_idx
+        ])
         spec_aug_mask_idxs.append(spec_aug_mask_idx)

     spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

     # expand masked indices to masked spans
-    spec_aug_mask_idxs = np.broadcast_to(
-        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
-    )
-    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape((batch_size, max_num_masked_span * mask_length))
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None],
+        (batch_size, max_num_masked_span, mask_length))
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
+        (batch_size, max_num_masked_span * mask_length))

     # add offset to the starting indexes so that indexes now create a span
     offsets = np.arange(mask_length)[None, None, :]
-    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
-        (batch_size, max_num_masked_span * mask_length)
-    )
+    offsets = np.broadcast_to(offsets, (
+        batch_size, max_num_masked_span, mask_length)).reshape(
+            (batch_size, max_num_masked_span * mask_length))
     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

     # ensure that we cannot have indices larger than sequence_length
     if spec_aug_mask_idxs.max() > sequence_length - 1:
-        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+        spec_aug_mask_idxs[spec_aug_mask_idxs >
+                           sequence_length - 1] = sequence_length - 1

     # scatter indices to mask
     np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
@@ -196,9 +192,9 @@ def _compute_mask_indices(
     return spec_aug_mask

-def _sample_negative_indices(
-    features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
-):
+def _sample_negative_indices(features_shape: Tuple,
+                             num_negatives: int,
+                             mask_time_indices: Optional[np.ndarray]=None):
     """
     Sample `num_negatives` vectors from feature vectors.
     """
@@ -208,23 +204,28 @@ def _sample_negative_indices(
     sequence_length_range = np.arange(sequence_length)

     # get `num_negatives` random vector indices from the same utterance
-    sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+    sampled_negative_indices = np.zeros(
+        shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)

-    mask_time_indices = (
-        mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
-    )
+    mask_time_indices = (mask_time_indices.astype(np.bool)
                         if mask_time_indices is not None else
                         np.ones(features_shape, dtype=np.bool))

     for batch_idx in range(batch_size):
         high = mask_time_indices[batch_idx].sum() - 1
-        mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+        mapped_masked_indices = sequence_length_range[mask_time_indices[
+            batch_idx]]

-        feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
-        sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+        feature_indices = np.broadcast_to(
+            np.arange(high + 1)[:, None], (high + 1, num_negatives))
+        sampled_indices = np.random.randint(
+            0, high, size=(high + 1, num_negatives))
         # avoid sampling the same positive vector, but keep the distribution uniform
         sampled_indices[sampled_indices >= feature_indices] += 1

         # remap to actual indices
-        sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+        sampled_negative_indices[batch_idx][mask_time_indices[
+            batch_idx]] = mapped_masked_indices[sampled_indices]

         # correct for batch size
         sampled_negative_indices[batch_idx] += batch_idx * sequence_length
@@ -243,8 +244,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Layer):
             self.out_conv_dim,
             kernel_size=config.conv_kernel[layer_id],
             stride=config.conv_stride[layer_id],
-            bias_attr=config.conv_bias,
-        )
+            bias_attr=config.conv_bias, )
         self.activation = ACT2FN[config.feat_extract_activation]

     def forward(self, hidden_states):
@@ -264,8 +264,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Layer):
             self.out_conv_dim,
             kernel_size=config.conv_kernel[layer_id],
             stride=config.conv_stride[layer_id],
-            bias_attr=config.conv_bias,
-        )
+            bias_attr=config.conv_bias, )
         self.layer_norm = nn.LayerNorm(self.out_conv_dim)
         self.activation = ACT2FN[config.feat_extract_activation]
@@ -290,11 +289,11 @@ class Wav2Vec2GroupNormConvLayer(nn.Layer):
             self.out_conv_dim,
             kernel_size=config.conv_kernel[layer_id],
             stride=config.conv_stride[layer_id],
-            bias_attr=config.conv_bias,
-        )
+            bias_attr=config.conv_bias, )

         self.activation = ACT2FN[config.feat_extract_activation]
-        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim)
+        self.layer_norm = nn.GroupNorm(
+            num_groups=self.out_conv_dim, num_channels=self.out_conv_dim)

     def forward(self, hidden_states):
         hidden_states = self.conv(hidden_states)
@@ -311,8 +310,7 @@ class Wav2Vec2PositionalConvEmbedding(nn.Layer):
             config.hidden_size,
             kernel_size=config.num_conv_pos_embeddings,
             padding=config.num_conv_pos_embeddings // 2,
-            groups=config.num_conv_pos_embedding_groups,
-        )
+            groups=config.num_conv_pos_embedding_groups, )

         self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
@@ -337,7 +335,7 @@ class Wav2Vec2SamePadLayer(nn.Layer):
     def forward(self, hidden_states):
         if self.num_pad_remove > 0:
-            hidden_states = hidden_states[:, :, : -self.num_pad_remove]
+            hidden_states = hidden_states[:, :, :-self.num_pad_remove]
         return hidden_states
@@ -349,11 +347,13 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
         if config.feat_extract_norm == "group":
             conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
-                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
+                Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
+                for i in range(config.num_feat_extract_layers - 1)
             ]
         elif config.feat_extract_norm == "layer":
             conv_layers = [
-                Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
+                Wav2Vec2LayerNormConvLayer(config, layer_id=i)
+                for i in range(config.num_feat_extract_layers)
             ]
         else:
             raise ValueError(
@@ -373,10 +373,12 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
         return hidden_states

 class Wav2Vec2FeatureProjection(nn.Layer):
     def __init__(self, config):
         super().__init__()
-        self.layer_norm = nn.LayerNorm(config.conv_dim[-1], epsilon=config.layer_norm_eps)
+        self.layer_norm = nn.LayerNorm(
+            config.conv_dim[-1], epsilon=config.layer_norm_eps)
         self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
         self.dropout = nn.Dropout(config.feat_proj_dropout)
@ -396,10 +398,9 @@ class Wav2Vec2Attention(nn.Layer):
self, self,
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
dropout: float = 0.0, dropout: float=0.0,
is_decoder: bool = False, is_decoder: bool=False,
bias: bool = True, bias: bool=True, ):
):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
@ -409,8 +410,7 @@ class Wav2Vec2Attention(nn.Layer):
if (self.head_dim * num_heads) != self.embed_dim: if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError( raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})." f" and `num_heads`: {num_heads}).")
)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder self.is_decoder = is_decoder
@ -420,17 +420,18 @@ class Wav2Vec2Attention(nn.Layer):
self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int):
return paddle.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)).transpose([0, 2, 1, 3]) return paddle.reshape(tensor, (bsz, seq_len, self.num_heads,
self.head_dim)).transpose([0, 2, 1, 3])
def forward( def forward(
self, self,
hidden_states: paddle.Tensor, hidden_states: paddle.Tensor,
key_value_states: Optional[paddle.Tensor] = None, key_value_states: Optional[paddle.Tensor]=None,
past_key_value: Optional[Tuple[paddle.Tensor]] = None, past_key_value: Optional[Tuple[paddle.Tensor]]=None,
attention_mask: Optional[paddle.Tensor] = None, attention_mask: Optional[paddle.Tensor]=None,
layer_head_mask: Optional[paddle.Tensor] = None, layer_head_mask: Optional[paddle.Tensor]=None,
output_attentions: bool = False, output_attentions: bool=False, ) -> Tuple[paddle.Tensor, Optional[
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel""" """Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer # if key_value_states are provided this layer is used as a cross-attention layer
@@ -455,7 +456,8 @@ class Wav2Vec2Attention(nn.Layer):
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             key_states = paddle.concat([past_key_value[0], key_states], axis=2)
-            value_states = paddle.concat([past_key_value[1], value_states], axis=2)
+            value_states = paddle.concat(
+                [past_key_value[1], value_states], axis=2)
         else:
             # self_attention
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -472,60 +474,68 @@ class Wav2Vec2Attention(nn.Layer):
             past_key_value = (key_states, value_states)

         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
-        query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
+        query_states = self._shape(query_states, tgt_len,
+                                   bsz).reshape(proj_shape)
         key_states = key_states.reshape(proj_shape)
         value_states = value_states.reshape(proj_shape)

         src_len = key_states.shape[1]
         attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))

         if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
             raise ValueError(
                 f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
-                f" {attn_weights.shape}"
-            )
+                f" {attn_weights.shape}")

         if attention_mask is not None:
             if attention_mask.shape != [bsz, 1, tgt_len, src_len]:
                 raise ValueError(
                     f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is {attention_mask.shape}"
                 )
-            attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attention_mask
-            attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len,
+                                                src_len) + attention_mask
+            attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len,
+                                                src_len)

-        attn_weights = nn.functional.softmax(attn_weights, axis=- 1)
+        attn_weights = nn.functional.softmax(attn_weights, axis=-1)

         if layer_head_mask is not None:
-            if layer_head_mask.shape != [self.num_heads,]:
+            if layer_head_mask.shape != [
+                    self.num_heads,
+            ]:
                 raise ValueError(
                     f"Head mask for a single layer should be of size {[self.num_heads,]}, but is"
-                    f" {layer_head_mask.shape}"
-                )
-            attn_weights = layer_head_mask.reshape((1, -1, 1, 1)) * attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len))
-            attn_weights = attn_weights.reshape((bsz * self.num_heads, tgt_len, src_len))
+                    f" {layer_head_mask.shape}")
+            attn_weights = layer_head_mask.reshape(
+                (1, -1, 1, 1)) * attn_weights.reshape(
+                    (bsz, self.num_heads, tgt_len, src_len))
+            attn_weights = attn_weights.reshape(
+                (bsz * self.num_heads, tgt_len, src_len))

         if output_attentions:
             # this operation is a bit awkward, but it's required to
             # make sure that attn_weights keeps its gradient.
             # In order to do so, attn_weights have to be reshaped
             # twice and have to be reused in the following
-            attn_weights_reshaped = attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len))
-            attn_weights = attn_weights_reshaped.reshape((bsz * self.num_heads, tgt_len, src_len))
+            attn_weights_reshaped = attn_weights.reshape(
+                (bsz, self.num_heads, tgt_len, src_len))
+            attn_weights = attn_weights_reshaped.reshape(
+                (bsz * self.num_heads, tgt_len, src_len))
         else:
             attn_weights_reshaped = None

-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+        attn_probs = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training)

         attn_output = paddle.bmm(attn_probs, value_states)

         if attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
             raise ValueError(
                 f"`attn_output` should be of size {[bsz, self.num_heads, tgt_len, self.head_dim]}, but is"
-                f" {attn_output.shape}"
-            )
+                f" {attn_output.shape}")

-        attn_output = attn_output.reshape((bsz, self.num_heads, tgt_len, self.head_dim))
+        attn_output = attn_output.reshape(
+            (bsz, self.num_heads, tgt_len, self.head_dim))
         attn_output = attn_output.transpose([0, 2, 1, 3])

         # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
@ -542,13 +552,15 @@ class Wav2Vec2FeedForward(nn.Layer):
super().__init__() super().__init__()
self.intermediate_dropout = nn.Dropout(config.activation_dropout) self.intermediate_dropout = nn.Dropout(config.activation_dropout)
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) self.intermediate_dense = nn.Linear(config.hidden_size,
config.intermediate_size)
if isinstance(config.hidden_act, str): if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act] self.intermediate_act_fn = ACT2FN[config.hidden_act]
else: else:
self.intermediate_act_fn = config.hidden_act self.intermediate_act_fn = config.hidden_act
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) self.output_dense = nn.Linear(config.intermediate_size,
config.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout) self.output_dropout = nn.Dropout(config.hidden_dropout)
def forward(self, hidden_states): def forward(self, hidden_states):
@ -568,18 +580,23 @@ class Wav2Vec2EncoderLayer(nn.Layer):
embed_dim=config.hidden_size, embed_dim=config.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
dropout=config.attention_dropout, dropout=config.attention_dropout,
is_decoder=False, is_decoder=False, )
)
self.dropout = nn.Dropout(config.hidden_dropout) self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config) self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(self,
hidden_states,
attention_mask=None,
output_attentions=False):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions hidden_states,
) attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states hidden_states = attn_residual + hidden_states
@ -587,10 +604,10 @@ class Wav2Vec2EncoderLayer(nn.Layer):
hidden_states = hidden_states + self.feed_forward(hidden_states) hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,) outputs = (hidden_states, )
if output_attentions: if output_attentions:
outputs += (attn_weights,) outputs += (attn_weights, )
return outputs return outputs
@ -602,27 +619,33 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Layer):
embed_dim=config.hidden_size, embed_dim=config.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
dropout=config.attention_dropout, dropout=config.attention_dropout,
is_decoder=False, is_decoder=False, )
)
self.dropout = nn.Dropout(config.hidden_dropout) self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config) self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.final_layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False): def forward(self,
hidden_states,
attention_mask=None,
output_attentions=False):
attn_residual = hidden_states attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention( hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions hidden_states,
) attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) hidden_states = hidden_states + self.feed_forward(
self.final_layer_norm(hidden_states))
outputs = (hidden_states,) outputs = (hidden_states, )
if output_attentions: if output_attentions:
outputs += (attn_weights,) outputs += (attn_weights, )
return outputs return outputs
@ -632,9 +655,13 @@ class Wav2Vec2Encoder(nn.Layer):
super().__init__() super().__init__()
self.config = config self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout) self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.LayerList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.LayerList([
Wav2Vec2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward( def forward(
@@ -643,22 +670,23 @@ class Wav2Vec2Encoder(nn.Layer):
             attention_mask=None,
             output_attentions=False,
             output_hidden_states=False,
-            return_dict=True,
-    ):
+            return_dict=True, ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

         if attention_mask is not None:
             # make sure padded tokens output 0
-            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
+                1, 1, hidden_states.shape[2])
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
+                dtype=hidden_states.dtype)
             attention_mask = attention_mask * np.iinfo(np.float32).min
-            attention_mask = attention_mask.expand(
-                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
-            )
+            attention_mask = attention_mask.expand(attention_mask.shape[0], 1,
+                                                   attention_mask.shape[-1],
+                                                   attention_mask.shape[-1])

         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
@@ -669,13 +697,14 @@ class Wav2Vec2Encoder(nn.Layer):
         for layer in self.layers:
             if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+                all_hidden_states = all_hidden_states + (hidden_states, )

             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = np.random.uniform(0, 1)

-            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
-            if not skip_the_layer:# or deepspeed_zero3_is_enabled:
+            skip_the_layer = True if self.training and (
+                dropout_probability < self.config.layerdrop) else False
+            if not skip_the_layer:  # or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
                     # create gradient checkpointing function
@@ -686,26 +715,30 @@ class Wav2Vec2Encoder(nn.Layer):
                         return custom_forward
                 else:
                     layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        output_attentions=output_attentions)
                 hidden_states = layer_outputs[0]

             if skip_the_layer:
                 layer_outputs = (None, None)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1], )

         if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states, )

         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v
+                for v in
+                [hidden_states, all_hidden_states, all_self_attentions]
+                if v is not None)
         return BaseModelOutput(
             last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
+            attentions=all_self_attentions, )

 class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
@@ -713,11 +746,13 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
         super().__init__()
         self.config = config
         self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
-        self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+        self.layer_norm = nn.LayerNorm(
+            config.hidden_size, epsilon=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.LayerList(
-            [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
-        )
+        self.layers = nn.LayerList([
+            Wav2Vec2EncoderLayerStableLayerNorm(config)
+            for _ in range(config.num_hidden_layers)
+        ])
         self.gradient_checkpointing = False

     def forward(
@@ -726,22 +761,24 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
             attention_mask=None,
             output_attentions=False,
             output_hidden_states=False,
-            return_dict=True,
-    ):
+            return_dict=True, ):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            expand_attention_mask = attention_mask.unsqueeze(-1).repeat_interleave(hidden_states.shape[2], axis=2)
+            expand_attention_mask = attention_mask.unsqueeze(
+                -1).repeat_interleave(
+                    hidden_states.shape[2], axis=2)
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
+                dtype=hidden_states.dtype)
             attention_mask = attention_mask * np.iinfo(np.float32).min
-            attention_mask = attention_mask.expand(
-                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
-            )
+            attention_mask = attention_mask.expand(attention_mask.shape[0], 1,
+                                                   attention_mask.shape[-1],
+                                                   attention_mask.shape[-1])

         position_embeddings = self.pos_conv_embed(hidden_states)
         hidden_states = hidden_states + position_embeddings
@@ -749,13 +786,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
         for layer in self.layers:
             if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+                all_hidden_states = all_hidden_states + (hidden_states, )

             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             dropout_probability = np.random.uniform(0, 1)

-            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
-            if not skip_the_layer:# or deepspeed_zero3_is_enabled:
+            skip_the_layer = True if self.training and (
+                dropout_probability < self.config.layerdrop) else False
+            if not skip_the_layer:  # or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                 if self.gradient_checkpointing and self.training:
@@ -767,28 +805,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
                         return custom_forward
                 else:
                     layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                        hidden_states,
+                        attention_mask=attention_mask,
+                        output_attentions=output_attentions)
                 hidden_states = layer_outputs[0]

             if skip_the_layer:
                 layer_outputs = (None, None)

             if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1], )

         hidden_states = self.layer_norm(hidden_states)

         if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
+            all_hidden_states = all_hidden_states + (hidden_states, )

         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+            return tuple(
+                v
+                for v in
+                [hidden_states, all_hidden_states, all_self_attentions]
+                if v is not None)
         return BaseModelOutput(
             last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
+            attentions=all_self_attentions, )

 class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
@ -810,9 +852,13 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        # storage for codebook variables (codewords)
        self.codevectors = paddle.static.create_parameter(
            shape=[
                1, self.num_groups * self.num_vars,
                config.codevector_dim // self.num_groups
            ],
            dtype='float32')
        self.weight_proj = nn.Linear(config.conv_dim[-1],
                                     self.num_groups * self.num_vars)

        # can be decayed for training
        self.temperature = 2
@ -826,7 +872,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        else:
            marginal_probs = probs.mean(axis=0)

        perplexity = paddle.exp(-paddle.sum(
            marginal_probs * paddle.log(marginal_probs + 1e-7), axis=-1)).sum()
        return perplexity

    def forward(self, hidden_states, mask_time_indices=None):
@ -834,35 +881,45 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
        # project to codevector dim
        hidden_states = self.weight_proj(hidden_states)
        hidden_states = hidden_states.reshape(
            (batch_size * sequence_length * self.num_groups, -1))

        if self.training:
            # sample code vector probs via gumbel in differentiable way
            codevector_probs = nn.functional.gumbel_softmax(
                hidden_states.astype('float32'),
                temperature=self.temperature,
                hard=True).astype(hidden_states.dtype)

            # compute perplexity
            codevector_soft_dist = nn.functional.softmax(
                hidden_states.reshape(
                    (batch_size * sequence_length, self.num_groups,
                     -1)).astype('float32'),
                axis=-1)
            perplexity = self._compute_perplexity(codevector_soft_dist,
                                                  mask_time_indices)
        else:
            # take argmax in non-differentiable way
            # compute hard codevector distribution (one hot)
            codevector_idx = hidden_states.argmax(axis=-1)
            codevector_probs = nn.functional.one_hot(
                codevector_idx, hidden_states.shape[-1]).astype(
                    hidden_states.dtype)
            codevector_probs = codevector_probs.reshape(
                (batch_size * sequence_length, self.num_groups, -1))

            perplexity = self._compute_perplexity(codevector_probs,
                                                  mask_time_indices)

        codevector_probs = codevector_probs.reshape(
            (batch_size * sequence_length, -1))
        # use probs to retrieve codevectors
        codevectors_per_group = codevector_probs.unsqueeze(
            -1) * self.codevectors
        codevectors = codevectors_per_group.reshape(
            (batch_size * sequence_length, self.num_groups, self.num_vars, -1))
        codevectors = codevectors.sum(-2).reshape(
            (batch_size, sequence_length, -1))

        return codevectors, perplexity
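The two branches above reduce to: differentiable straight-through sampling at train time, hard argmax one-hots at eval time. A toy sketch of both paths (shapes and values are illustrative, not from this patch):

import paddle
import paddle.nn.functional as F

logits = paddle.randn([4, 8])  # 4 frames over a codebook of 8 entries

# train path: one-hot samples with a straight-through gradient
hard_probs = F.gumbel_softmax(logits, temperature=2.0, hard=True)

# eval path: plain argmax turned into a one-hot distribution
one_hot = F.one_hot(logits.argmax(axis=-1),
                    logits.shape[-1]).astype(logits.dtype)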
@ -878,7 +935,9 @@ class Wav2Vec2Adapter(nn.Layer):
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = nn.LayerList(
            Wav2Vec2AdapterLayer(config)
            for _ in range(config.num_adapter_layers))
        self.layerdrop = config.layerdrop

    def forward(self, hidden_states):
@ -906,8 +965,7 @@ class Wav2Vec2AdapterLayer(nn.Layer):
            2 * config.output_hidden_size,
            config.adapter_kernel_size,
            stride=config.adapter_stride,
            padding=1, )

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)
@ -927,7 +985,11 @@ class Wav2Vec2Model(nn.Layer):
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            # self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_())
            # self.masked_spec_embed = paddle.uniform([config.hidden_size])
            self.masked_spec_embed = paddle.static.create_parameter(
                shape=[config.hidden_size],
                dtype='float32',
                default_initializer=paddle.nn.initializer.Uniform(
                    low=0, high=1.0))

        if config.do_stable_layer_norm:
            self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
        else:
@ -948,9 +1010,8 @@ class Wav2Vec2Model(nn.Layer):
    def _mask_hidden_states(
            self,
            hidden_states: paddle.Tensor,
            mask_time_indices: Optional[paddle.Tensor]=None,
            attention_mask: Optional[paddle.Tensor]=None, ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
@ -963,17 +1024,19 @@ class Wav2Vec2Model(nn.Layer):
        batch_size, sequence_length, hidden_size = hidden_states.shape

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.astype(
                hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks, )
            mask_time_indices = paddle.to_tensor(
                mask_time_indices, dtype=paddle.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.astype(
                hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
@ -981,10 +1044,11 @@ class Wav2Vec2Model(nn.Layer):
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks, )
            mask_feature_indices = paddle.to_tensor(
                mask_feature_indices, dtype=paddle.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(
                [-1, sequence_length, -1])
            hidden_states[mask_feature_indices] = 0

        return hidden_states
@ -992,16 +1056,16 @@ class Wav2Vec2Model(nn.Layer):
    def forward(
            self,
            input_values: Optional[paddle.Tensor],
            attention_mask: Optional[paddle.Tensor]=None,
            mask_time_indices: Optional[paddle.Tensor]=None,
            output_attentions: Optional[bool]=None,
            output_hidden_states: Optional[bool]=None,
            return_dict: Optional[bool]=None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose([0, 2, 1])
@ -1009,20 +1073,20 @@ class Wav2Vec2Model(nn.Layer):
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False)

        hidden_states, extract_features = self.feature_projection(
            extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states,
            mask_time_indices=mask_time_indices,
            attention_mask=attention_mask)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict, )

        hidden_states = encoder_outputs[0]
@ -1036,8 +1100,7 @@ class Wav2Vec2Model(nn.Layer):
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions, )

    def post_init(self):
        """
@ -1048,8 +1111,10 @@ class Wav2Vec2Model(nn.Layer):
        # self._backward_compatibility_gradient_checkpointing()
        pass


class Wav2Vec2ConfigPure():
    model_type = "wav2vec2"

    def __init__(self, config):
        self.output_attentions = False
        self.output_hidden_states = False
@ -1084,17 +1149,14 @@ class Wav2Vec2ConfigPure():
        self.do_stable_layer_norm = config.do_stable_layer_norm
        self.use_weighted_layer_sum = config.use_weighted_layer_sum

        if ((len(self.conv_stride) != self.num_feat_extract_layers) or
                (len(self.conv_kernel) != self.num_feat_extract_layers) or
                (len(self.conv_dim) != self.num_feat_extract_layers)):
            raise ValueError(
                "Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
                " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
                f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
                f" `len(config.conv_kernel) = {len(self.conv_kernel)}`.")

        # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
        self.apply_spec_augment = config.apply_spec_augment
@ -7,10 +7,8 @@ Authors
 * Samuele Cornell 2020
 * Sarthak Yadav 2022
"""
import numpy as np
import paddle


def blackman_window(window_length, periodic=True):
@ -97,8 +95,7 @@ def convolve1d(
        stride=1,
        groups=1,
        use_fft=False,
        rotation_index=0, ):
    """Use paddle.nn.functional to perform 1d padding and conv.

    Arguments
    ---------
@ -150,8 +147,7 @@ def convolve1d(
    # Padding can be a tuple (left_pad, right_pad) or an int
    if isinstance(padding, tuple):
        waveform = paddle.nn.functional.pad(
            x=waveform, pad=padding, mode=pad_type, data_format='NCL')

    # This approach uses FFT, which is more efficient if the kernel is large
    if use_fft:
@ -165,9 +161,7 @@ def convolve1d(
        # Perform rotation to ensure alignment
        zeros = paddle.zeros(
            [kernel.shape[0], kernel.shape[1], zero_length], dtype=kernel.dtype)
        after_index = kernel[..., rotation_index:]
        before_index = kernel[..., :rotation_index]
        kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
@ -185,12 +179,12 @@ def convolve1d(
        weight=kernel,
        stride=stride,
        groups=groups,
        padding=padding if not isinstance(padding, tuple) else 0, )

    # Return time dimension to the second dimension.
    return convolved.transpose([0, 2, 1])


def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    """Returns a notch filter constructed from a high-pass and low-pass filter.

    (from https://tomroelandts.com/articles/
@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
            return paddle.sin(x) / x

        # The zero is at the middle index
        return paddle.concat(
            [_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])

    # Compute a low-pass filter with cutoff frequency notch_freq.
    hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
    # Adding filters creates notch filter
    return (hlpf + hhpf).reshape([1, -1, 1])
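Usage sketch: the returned kernel plugs directly into `convolve1d` above; the notch frequency value below is arbitrary and purely illustrative:

import paddle

signal = paddle.rand([1, 16000, 1])            # (batch, time, channels)
kernel = notch_filter(0.25, filter_width=101)  # (1, filter_width, 1)
filtered = convolve1d(signal, kernel, padding=101 // 2)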
@ -1,11 +1,12 @@
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import compute_amplitude
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter


class SpeedPerturb(nn.Layer):
    """Slightly speed up or slow down an audio signal.
@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
""" """
def __init__( def __init__(
self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0, self,
): orig_freq,
speeds=[90, 100, 110],
perturb_prob=1.0, ):
super().__init__() super().__init__()
self.orig_freq = orig_freq self.orig_freq = orig_freq
self.speeds = speeds self.speeds = speeds
@ -73,11 +76,12 @@ class SpeedPerturb(nn.Layer):
            return waveform.clone()

        # Perform a random perturbation
        self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
        perturbed_waveform = self.resamplers[self.samp_index](waveform)

        return perturbed_waveform


class Resample(nn.Layer):
    """This class resamples an audio signal using sinc-based interpolation.
@ -94,9 +98,12 @@ class Resample(nn.Layer):
        Controls the sharpness of the filter, larger numbers result in a
        sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
    """

    def __init__(
            self,
            orig_freq=16000,
            new_freq=16000,
            lowpass_filter_width=6, ):
        super().__init__()
        self.orig_freq = orig_freq
        self.new_freq = new_freq
@ -193,8 +200,7 @@ class Resample(nn.Layer):
        window_size = self.weights.shape[1]
        tot_output_samp = self._output_samples(wave_len)
        resampled_waveform = paddle.zeros(
            (batch_size, num_channels, tot_output_samp))
        # self.weights = self.weights.to(waveforms.device)

        # Check weights are on correct device
@ -222,28 +228,25 @@ class Resample(nn.Layer):
            right_padding = max(0, end_index + 1 - current_wave_len)
            left_padding = max(0, -first_index)
            wave_to_conv = paddle.nn.functional.pad(
                wave_to_conv, (left_padding, right_padding), data_format='NCL')
            conv_wave = paddle.nn.functional.conv1d(
                x=wave_to_conv,
                weight=self.weights[i].tile([num_channels, 1, 1]),
                stride=self.conv_stride,
                groups=num_channels, )

            # we want conv_wave[:, i] to be at
            # output[:, i + n*conv_transpose_stride]
            dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
                conv_wave, eye, stride=self.conv_transpose_stride)

            # pad dilated_conv_wave so it reaches the output length if needed.
            left_padding = i
            previous_padding = left_padding + dilated_conv_wave.shape[-1]
            right_padding = max(0, tot_output_samp - previous_padding)
            dilated_conv_wave = paddle.nn.functional.pad(
                dilated_conv_wave, (left_padding, right_padding),
                data_format='NCL')
            dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]

            resampled_waveform += dilated_conv_wave
@ -326,9 +329,7 @@ class Resample(nn.Layer):
        window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)

        assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
        output_t = paddle.arange(start=0.0, end=self.output_samples)
        output_t /= self.new_freq
        min_t = output_t - window_width
        max_t = output_t + window_width
@ -346,23 +347,16 @@ class Resample(nn.Layer):
        inside_window_indices = delta_t.abs() < (window_width)
        # raised-cosine (Hanning) window with width `window_width`
        weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
            2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
            delta_t[inside_window_indices]))

        t_eq_zero_indices = delta_t == 0.0
        t_not_eq_zero_indices = ~t_eq_zero_indices

        # sinc filter function
        weights[t_not_eq_zero_indices] *= paddle.sin(
            2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
                math.pi * delta_t[t_not_eq_zero_indices])

        # limit of the function at t = 0
        weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
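Stated compactly, each weight is a low-pass sinc tapered by a raised-cosine window, with the t = 0 limit handled separately. A standalone sketch of the same formula (cutoff and width values are illustrative only):

import math
import paddle

f_c, width = 4000.0, 6.0  # cutoff (Hz) and filter width, both illustrative
t = paddle.linspace(-width / (2 * f_c), width / (2 * f_c), 101)
window = 0.5 * (1 + paddle.cos(2 * math.pi * f_c / width * t))
# paddle.where selects the 2*f_c limit at t == 0, the sinc value elsewhere
sinc = paddle.where(t == 0,
                    paddle.full_like(t, 2 * f_c),
                    paddle.sin(2 * math.pi * f_c * t) / (math.pi * t))
weights = window * sinc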
@ -411,8 +405,7 @@ class DropFreq(nn.Layer):
            drop_count_low=1,
            drop_count_high=2,
            drop_width=0.05,
            drop_prob=1, ):
        super().__init__()
        self.drop_freq_low = drop_freq_low
        self.drop_freq_high = drop_freq_high
@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
        # Pick number of frequencies to drop
        drop_count = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=(1, ), )

        # Pick a frequency to drop
        drop_range = self.drop_freq_high - self.drop_freq_low
        drop_frequency = (
            paddle.rand(drop_count) * drop_range + self.drop_freq_low)

        # Filter parameters
        filter_length = 101
        pad = filter_length // 2
@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
        # Subtract each frequency
        for frequency in drop_frequency:
            notch_kernel = notch_filter(
                frequency,
                filter_length,
                self.drop_width, )
            drop_filter = convolve1d(drop_filter, notch_kernel, pad)

        # Apply filter
@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
        # Remove channels dimension if added
        return dropped_waveform.squeeze(-1)


class DropChunk(nn.Layer):
    """This class drops portions of the input signal.

    Using `DropChunk` as an augmentation strategy helps a model learn to rely
@ -523,8 +518,7 @@ class DropChunk(nn.Layer):
            drop_start=0,
            drop_end=None,
            drop_prob=1,
            noise_factor=0.0, ):
        super().__init__()
        self.drop_length_low = drop_length_low
        self.drop_length_high = drop_length_high
@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
        drop_times = paddle.randint(
            low=self.drop_count_low,
            high=self.drop_count_high + 1,
            shape=(batch_size, ), )

        # Iterate batch to set mask
        for i in range(batch_size):
@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
            length = paddle.randint(
                low=self.drop_length_low,
                high=self.drop_length_high + 1,
                shape=(drop_times[i], ), )

            # Compute range of starting locations
            start_min = self.drop_start
@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
            # Pick starting locations
            start = paddle.randint(
                low=start_min,
                high=start_max + 1,
                shape=(drop_times[i], ), )

            end = start + length

            # Update waveform
            if not self.noise_factor:
                for j in range(drop_times[i]):
                    dropped_waveform[i, start[j]:end[j]] = 0.0
            else:
                # Uniform distribution of -2 to +2 * avg amplitude should
                # preserve the average for normalization
@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
                    # zero-center the noise distribution
                    noise_vec = paddle.rand([length[j]])
                    noise_vec = 2 * noise_max * noise_vec - noise_max
                    dropped_waveform[i, start[j]:end[j]] = noise_vec

        return dropped_waveform
@ -691,25 +684,21 @@ class TimeDomainSpecAugment(nn.Layer):
            drop_chunk_count_high=5,
            drop_chunk_length_low=1000,
            drop_chunk_length_high=2000,
            drop_chunk_noise_factor=0, ):
        super().__init__()
        self.speed_perturb = SpeedPerturb(
            perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
        self.drop_freq = DropFreq(
            drop_prob=drop_freq_prob,
            drop_count_low=drop_freq_count_low,
            drop_count_high=drop_freq_count_high, )
        self.drop_chunk = DropChunk(
            drop_prob=drop_chunk_prob,
            drop_count_low=drop_chunk_count_low,
            drop_count_high=drop_chunk_count_high,
            drop_length_low=drop_chunk_length_low,
            drop_length_high=drop_chunk_length_high,
            noise_factor=drop_chunk_noise_factor, )

    def forward(self, waveforms, lengths):
        """Returns the distorted waveforms.
@ -1,25 +1,19 @@
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Tuple

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.utility import log_add


class Wav2vec2ASR(nn.Layer):
    def __init__(self, config: dict):
@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
        for param in wav2vec2.parameters():
            param.trainable = False
        self.wav2vec2 = wav2vec2
        self.enc = VanillaNN(
            input_shape=[None, None, wav2vec2_config.hidden_size],
            activation=nn.LeakyReLU,
            dnn_blocks=config.dnn_blocks,
            dnn_neurons=config.dnn_neurons)
        self.ctc = CTC(odim=config.output_dim,
                       enc_n_units=config.dnn_neurons,
                       blank_id=config.blank_id,
                       dropout_rate=config.ctc_dropout_rate,
                       reduction=True)
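    # Net effect of this head (editorial sketch, shapes illustrative):
    #   wav   (B, T)  -> wav2vec2 (frozen) -> feats (B, T', hidden_size)
    #   feats         -> enc               -> x     (B, T', dnn_neurons)
    #   x             -> ctc               -> CTC loss vs. target token ids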

    def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
        if self.normalize_wav:
@ -51,7 +53,8 @@ class Wav2vec2ASR(nn.Layer):
        x = self.enc(feats)
        x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
        target_lens = (target_lens_rate *
                       target.shape[1]).round().astype(paddle.int64)

        ctc_loss = self.ctc(x, x_lens, target, target_lens)
        return ctc_loss
@ -63,7 +66,8 @@ class Wav2vec2ASR(nn.Layer):
               decoding_method: str,
               beam_size: int):
        batch_size = feats.shape[0]

        if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
            logger.error(
                f'decoding mode {decoding_method} must be running with batch_size == 1'
            )
@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
        # with other batch decoding mode
        elif decoding_method == 'ctc_prefix_beam_search':
            assert feats.shape[0] == 1
            hyp = self.ctc_prefix_beam_search(feats, beam_size)
            res = [text_feature.defeaturize(hyp)]
            res_tokenids = [hyp]
        else:
            raise ValueError(
                f"wav2vec2 does not support decoding method: {decoding_method}")

        return res, res_tokenids
@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
        model = cls(config)
        return model

    def ctc_greedy_search(self, wav) -> List[List[int]]:
        """ Apply CTC greedy search

        Args:
            wav (paddle.Tensor): (batch, max_len)
@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
List[List[int]]: best path result List[List[int]]: best path result
""" """
batch_size = wav.shape[0] batch_size = wav.shape[0]
wav = wav[:,:,0] wav = wav[:, :, 0]
if self.normalize_wav: if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:]) wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output # Extract wav2vec output
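Greedy CTC decoding is just a frame-wise argmax followed by collapsing repeats and dropping blanks; a self-contained sketch of the idea (blank id assumed to be 0, names illustrative):

import paddle

def greedy_decode(ctc_log_probs, blank_id=0):
    # ctc_log_probs: (B, T, V) log-probabilities over the vocabulary
    best_path = ctc_log_probs.argmax(axis=-1)  # (B, T)
    hyps = []
    for seq in best_path.tolist():
        out, prev = [], None
        for tok in seq:
            if tok != prev and tok != blank_id:  # collapse repeats, drop blanks
                out.append(tok)
            prev = tok
        hyps.append(out)
    return hyps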
@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
        return hyps

    def _ctc_prefix_beam_search(
            self,
            wav,
            beam_size,
            blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
        """ CTC prefix beam search inner implementation

        Args:
            wav (paddle.Tensor): (batch, max_len, feat_dim)
@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
            paddle.Tensor: encoder output, (1, max_len, encoder_dim),
                it will be used for rescoring in attention rescoring mode
        """
        wav = wav[:, :, 0]

        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape[1:])
@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
        Returns:
            List[int]: CTC prefix beam search nbest results
        """
        hyps = self._ctc_prefix_beam_search(wav, beam_size)
        return hyps[0][0]