format wav2vec2 demo

pull/2518/head
tianhao zhang 2 years ago
parent 6e429f0513
commit 19180d359d

@@ -33,7 +33,7 @@ filename =
# Specify a list of codes to ignore.
ignore =
W503
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125,E129
W291,W293,W605
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying

@@ -3,7 +3,7 @@
* asr0 - deepspeech2 Streaming/Non-Streaming
* asr1 - transformer/conformer Streaming/Non-Streaming
* asr2 - transformer/conformer Streaming/Non-Streaming with Kaldi feature
* asr3 - wav2vecASR, ASR model with pre-trained wav2vec2 and CTC
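For orientation, asr3 pairs a pre-trained wav2vec2 encoder with a CTC head. A minimal, illustrative sketch of CTC greedy decoding (collapse repeats, then drop blanks; the function name here is hypothetical, not the recipe's API):

import numpy as np

def ctc_greedy_decode(logits, blank_id=0):
    # logits: (T, V) per-frame scores; returns collapsed token ids
    best = logits.argmax(axis=-1)
    out, prev = [], None
    for t in best:
        if t != prev and t != blank_id:  # collapse repeats, drop blanks
            out.append(int(t))
        prev = t
    return out

print(ctc_greedy_decode(np.random.randn(50, 32)))  # 50 frames, 32-symbol vocab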
## Data
| Data Subset | Duration in Seconds |

@@ -382,6 +382,36 @@ class LogMelSpectrogramKaldi():
return mat
class WavProcess():
def __init__(self, dither=0.1):
"""
Args:
dither (float): Dithering constant
"""
self.dither = dither
def __call__(self, x, train):
"""
Args:
x (np.ndarray): shape (Ti,)
train (bool): True when in training mode.
Raises:
ValueError: multi-channel input of shape (Ti, C) is not supported
Returns:
np.ndarray: waveform of shape (Ti, 1)
"""
dither = self.dither if train else 0.0
# NOTE: `dither` is computed but not applied; the waveform is only
# reshaped to (Ti, 1) below.
if x.ndim != 1:
raise ValueError("x of shape [Time, Channel] is not supported")
waveform = np.expand_dims(x, -1)
return waveform
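Side note, hedged: since `dither` above is never used, a typical dithering step, if one were applied, would look like this (illustrative only, not part of this commit):

import numpy as np

def apply_dither(waveform, dither):
    # add low-level Gaussian noise to avoid hard-zero regions
    if dither > 0.0:
        waveform = waveform + dither * np.random.standard_normal(waveform.shape)
    return waveform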
class LogMelSpectrogramKaldi_decay():
def __init__(
self,

@@ -41,6 +41,7 @@ import_alias = dict(
utterance_cmvn="paddlespeech.audio.transform.cmvn:UtteranceCMVN",
fbank="paddlespeech.audio.transform.spectrogram:LogMelSpectrogram",
spectrogram="paddlespeech.audio.transform.spectrogram:Spectrogram",
wav_process="paddlespeech.audio.transform.spectrogram:WavProcess",
stft="paddlespeech.audio.transform.spectrogram:Stft",
istft="paddlespeech.audio.transform.spectrogram:IStft",
stft2fbank="paddlespeech.audio.transform.spectrogram:Stft2LogMelSpectrogram",

@@ -27,6 +27,7 @@ from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class Wav2vec2Infer():
def __init__(self, config, args):
self.args = args
@@ -34,8 +35,7 @@ class Wav2vec2Infer():
self.audio_file = args.audio_file
self.text_feature = TextFeaturizer(
unit_type=config.unit_type,
vocab=config.vocab_filepath)
unit_type=config.unit_type, vocab=config.vocab_filepath)
paddle.set_device('gpu' if self.args.ngpu > 0 else 'cpu')
# model
@@ -63,10 +63,10 @@ class Wav2vec2Infer():
xs = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
decode_config = self.config.decode
result_transcripts, result_tokenids = self.model.decode(
xs,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size)
xs,
text_feature=self.text_feature,
decoding_method=decode_config.decoding_method,
beam_size=decode_config.beam_size)
rsl = result_transcripts[0]
utt = Path(self.audio_file).name
logger.info(f"hyp: {utt} {rsl}")

@@ -18,38 +18,38 @@ import time
from collections import defaultdict
from collections import OrderedDict
from contextlib import nullcontext
from paddlespeech.s2t.utils import mp_tools
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import TimeDomainSpecAugment
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.models.wav2vec2.wav2vec2_ASR import Wav2vec2ASR
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
from paddlespeech.s2t.training.reporter import report
from paddlespeech.s2t.training.scheduler import LRSchedulerFactory
from paddlespeech.s2t.training.timer import Timer
from paddlespeech.s2t.training.trainer import Trainer
from paddlespeech.s2t.utils.utility import UpdateConfig
from paddlespeech.s2t.utils import error_rate
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils import mp_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class Wav2Vec2ASRTrainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)
self.avg_train_loss = 0
def train_batch(self, batch_index, batch, msg):
train_conf = self.config
start = time.time()
@@ -58,7 +58,7 @@ class Wav2Vec2ASRTrainer(Trainer):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
wav = wav[:, :, 0]
wav = self.speech_augmentation(wav, wavs_lens_rate)
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
# print(wav, wavs_lens_rate, target, target_lens_rate)
@@ -108,7 +108,8 @@ class Wav2Vec2ASRTrainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -116,7 +117,7 @@ class Wav2Vec2ASRTrainer(Trainer):
utt, wav, wavs_lens, target, target_lens = batch
wavs_lens_rate = wavs_lens / wav.shape[1]
target_lens_rate = target_lens / target.shape[1]
wav = wav[:,:,0]
wav = wav[:, :, 0]
loss = self.model(wav, wavs_lens_rate, target, target_lens_rate)
if paddle.isfinite(loss):
@@ -134,7 +135,8 @@ class Wav2Vec2ASRTrainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -155,7 +157,8 @@ class Wav2Vec2ASRTrainer(Trainer):
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -223,14 +226,18 @@ class Wav2Vec2ASRTrainer(Trainer):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
@@ -312,14 +319,14 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
self.text_featurizer = TextFeaturizer(
unit_type=config.unit_type, vocab=config.vocab_filepath)
self.vocab_list = self.text_featurizer.vocab_list
def id2token(self, texts, texts_len):
""" ord() id to chr() chr """
trans = []
for text, n in zip(texts, texts_len):
n = n.numpy().item()
ids = text[:n]
trans.append(
self.text_featurizer.defeaturize(ids.numpy().tolist()))
trans.append(self.text_featurizer.defeaturize(ids.numpy().tolist()))
return trans
def compute_metrics(self,
@@ -337,10 +344,10 @@ class Wav2Vec2ASRTester(Wav2Vec2ASRTrainer):
start_time = time.time()
target_transcripts = self.id2token(texts, texts_len)
result_transcripts, result_tokenids = self.model.decode(
audio,
text_feature=self.text_featurizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size)
audio,
text_feature=self.text_featurizer,
decoding_method=decode_cfg.decoding_method,
beam_size=decode_cfg.beam_size)
decode_time = time.time() - start_time
for utt, target, result, rec_tids in zip(

@@ -3,6 +3,7 @@ Authors
* Elena Rastorgueva 2020
"""
import paddle
from paddlespeech.s2t.models.wav2vec2.modules import containers
from paddlespeech.s2t.models.wav2vec2.modules import linear
@@ -27,12 +28,11 @@ class VanillaNN(containers.Sequential):
"""
def __init__(
self,
input_shape,
activation=paddle.nn.LeakyReLU,
dnn_blocks=2,
dnn_neurons=512,
):
self,
input_shape,
activation=paddle.nn.LeakyReLU,
dnn_blocks=2,
dnn_neurons=512, ):
super().__init__(input_shape=input_shape)
for block_index in range(dnn_blocks):
@@ -40,6 +40,5 @@ class VanillaNN(containers.Sequential):
linear.Linear,
n_neurons=dnn_neurons,
bias=True,
layer_name="linear",
)
layer_name="linear", )
self.append(activation(), layer_name="act")
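Usage sketch for VanillaNN, assuming the SpeechBrain-style `input_shape` convention of [batch, time, features] (the call below is illustrative):

import paddle

model = VanillaNN(input_shape=[None, None, 40], dnn_blocks=2, dnn_neurons=512)
x = paddle.randn([8, 120, 40])  # batch of 8, 120 frames, 40-dim features
y = model(x)                    # expected shape: [8, 120, 512]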

@@ -11,12 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from packaging import version
from paddle import Tensor, nn
from paddle import nn
from paddle import Tensor
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
@@ -29,7 +27,9 @@ class NewGELUActivation(nn.Layer):
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
return 0.5 * input * (1.0 + paddle.tanh(
math.sqrt(2.0 / math.pi) *
(input + 0.044715 * paddle.pow(input, 3.0))))
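A standalone sanity check of this tanh approximation against the exact GELU, 0.5*x*(1+erf(x/sqrt(2))) (numpy-only sketch):

import math
import numpy as np

x = np.linspace(-3.0, 3.0, 7)
approx = 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
exact = 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))
print(np.max(np.abs(approx - exact)))  # small; the two agree to within ~1e-3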
class GELUActivation(nn.Layer):
@@ -40,7 +40,7 @@ class GELUActivation(nn.Layer):
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
def __init__(self, use_gelu_python: bool=False):
super().__init__()
self.act = nn.functional.gelu
@@ -57,7 +57,9 @@ class FastGELUActivation(nn.Layer):
"""
def forward(self, input: Tensor) -> Tensor:
return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
return 0.5 * input * (
1.0 + paddle.tanh(input * 0.7978845608 *
(1.0 + 0.044715 * input * input)))
class QuickGELUActivation(nn.Layer):
@@ -84,7 +86,8 @@ class ClippedGELUActivation(nn.Layer):
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
raise ValueError(
f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
@@ -161,7 +164,9 @@ def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
raise KeyError(
f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
)
# For backwards compatibility with: from activations import gelu_python
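Usage is a plain dictionary lookup (sketch; "gelu" is assumed to be one of the registered keys):

act = get_activation("gelu")  # returns the layer registered in ACT2FN
# an unknown name raises KeyError and lists the valid keys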

@@ -1,8 +1,7 @@
import paddle
import inspect
import logging
import operator
import functools
import paddle
class Sequential(paddle.nn.LayerDict):
"""A sequence of modules with potentially inferring shape on construction.
@@ -98,13 +97,12 @@ class Sequential(paddle.nn.LayerDict):
# Finally, append the layer.
try:
self[layer_name] = layer
# self.add_module(layer_name, layer)
# self.add_module(layer_name, layer)
except TypeError:
raise ValueError(
"Must pass `input_shape` at initialization and use "
"modules that take `input_shape` to infer shape when "
"using `append()`."
)
"using `append()`.")
def get_output_shape(self):
"""Returns expected shape of the output.

@@ -3,10 +3,10 @@ Authors
* Mirco Ravanelli 2020
* Davide Borra 2021
"""
import logging
import paddle
import paddle.nn as nn
from paddlespeech.s2t.modules import align
logger = logging.getLogger(__name__)
@@ -37,13 +37,12 @@ class Linear(paddle.nn.Layer):
"""
def __init__(
self,
n_neurons,
input_shape=None,
input_size=None,
bias=True,
combine_dims=False,
):
self,
n_neurons,
input_shape=None,
input_size=None,
bias=True,
combine_dims=False, ):
super().__init__()
self.combine_dims = combine_dims

@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import fields
from typing import Optional
from typing import Tuple
import paddle
@@ -41,10 +41,13 @@ class ModelOutput(OrderedDict):
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
raise ValueError(
f"{self.__class__.__name__} should not have more than one required field."
)
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
other_fields_are_none = all(
getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not paddle.is_tensor(first_field):
if isinstance(first_field, dict):
@@ -61,11 +64,9 @@ class ModelOutput(OrderedDict):
# set the associated fields
if first_field_iterator:
for element in iterator:
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
if (not isinstance(element, (list, tuple)) or
not len(element) == 2 or
not isinstance(element[0], str)):
break
setattr(self, element[0], element[1])
if element[1] is not None:
@@ -79,16 +80,23 @@ class ModelOutput(OrderedDict):
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
raise Exception(
f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance."
)
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
raise Exception(
f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance."
)
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
raise Exception(
f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
raise Exception(
f"You cannot use ``update`` on a {self.__class__.__name__} instance."
)
def __getitem__(self, k):
if isinstance(k, str):

@@ -13,24 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Paddle Wav2Vec2 model."""
import math
import warnings
import paddle
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
import paddle
from paddle import nn
from paddlespeech.s2t.models.wav2vec2.modules.activations import ACT2FN
from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import (
BaseModelOutput,
Wav2Vec2BaseModelOutput,
ModelOutput
)
import pdb
from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import BaseModelOutput
from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import ModelOutput
from paddlespeech.s2t.models.wav2vec2.modules.modeling_outputs import Wav2Vec2BaseModelOutput
from paddlespeech.s2t.utils.log import Log
logger = Log(__name__).getlog()
@@ -78,12 +73,11 @@ class Wav2Vec2ForPreTrainingOutput(ModelOutput):
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[paddle.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[paddle.Tensor]=None,
min_masks: int=0, ) -> np.ndarray:
"""
Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
@@ -109,8 +103,7 @@ def _compute_mask_indices(
if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
f" and `sequence_length`: {sequence_length}`"
)
f" and `sequence_length`: {sequence_length}`")
# epsilon is used for probabilistic rounding
epsilon = np.random.rand(1).item()
@@ -131,11 +124,9 @@ def _compute_mask_indices(
return num_masked_span
# compute number of masked spans in batch
input_lengths = (
attention_mask.sum(-1).detach().tolist()
if attention_mask is not None
else [sequence_length for _ in range(batch_size)]
)
input_lengths = (attention_mask.sum(-1).detach().tolist()
if attention_mask is not None else
[sequence_length for _ in range(batch_size)])
# SpecAugment mask to fill
spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
@@ -152,8 +143,9 @@ def _compute_mask_indices(
# get random indices to mask
spec_aug_mask_idx = np.random.choice(
np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
)
np.arange(input_length - (mask_length - 1)),
num_masked_span,
replace=False)
# pick first sampled index that will serve as a dummy index to pad vector
# to ensure same dimension for all batches due to probabilistic rounding
@@ -166,29 +158,33 @@ def _compute_mask_indices(
else:
dummy_mask_idx = spec_aug_mask_idx[0]
spec_aug_mask_idx = np.concatenate(
[spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
)
spec_aug_mask_idx = np.concatenate([
spec_aug_mask_idx,
np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) *
dummy_mask_idx
])
spec_aug_mask_idxs.append(spec_aug_mask_idx)
spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
# expand masked indices to masked spans
spec_aug_mask_idxs = np.broadcast_to(
spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape((batch_size, max_num_masked_span * mask_length))
spec_aug_mask_idxs[:, :, None],
(batch_size, max_num_masked_span, mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
(batch_size, max_num_masked_span * mask_length))
# add offset to the starting indexes so that indexes now create a span
offsets = np.arange(mask_length)[None, None, :]
offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
(batch_size, max_num_masked_span * mask_length)
)
offsets = np.broadcast_to(offsets, (
batch_size, max_num_masked_span, mask_length)).reshape(
(batch_size, max_num_masked_span * mask_length))
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
# ensure that we cannot have indices larger than sequence_length
if spec_aug_mask_idxs.max() > sequence_length - 1:
spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
spec_aug_mask_idxs[spec_aug_mask_idxs >
sequence_length - 1] = sequence_length - 1
# scatter indices to mask
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
@@ -196,9 +192,9 @@ def _compute_mask_indices(
return spec_aug_mask
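A quick usage sketch (random values; shapes only):

# mask roughly half of 100 frames in spans of 10, for a batch of 2
mask = _compute_mask_indices((2, 100), mask_prob=0.5, mask_length=10)
print(mask.shape, mask.sum(axis=-1))  # (2, 100), up to ~50 masked frames per row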
def _sample_negative_indices(
features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
):
def _sample_negative_indices(features_shape: Tuple,
num_negatives: int,
mask_time_indices: Optional[np.ndarray]=None):
"""
Sample `num_negatives` vectors from feature vectors.
"""
@@ -208,23 +204,28 @@ def _sample_negative_indices(
sequence_length_range = np.arange(sequence_length)
# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
sampled_negative_indices = np.zeros(
shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
mask_time_indices = (
mask_time_indices.astype(np.bool) if mask_time_indices is not None else np.ones(features_shape, dtype=np.bool)
)
mask_time_indices = (mask_time_indices.astype(bool)
if mask_time_indices is not None else
np.ones(features_shape, dtype=bool))
for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
mapped_masked_indices = sequence_length_range[mask_time_indices[
batch_idx]]
feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
feature_indices = np.broadcast_to(
np.arange(high + 1)[:, None], (high + 1, num_negatives))
sampled_indices = np.random.randint(
0, high, size=(high + 1, num_negatives))
# avoid sampling the same positive vector, but keep the distribution uniform
sampled_indices[sampled_indices >= feature_indices] += 1
# remap to actual indices
sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
sampled_negative_indices[batch_idx][mask_time_indices[
batch_idx]] = mapped_masked_indices[sampled_indices]
# correct for batch size
sampled_negative_indices[batch_idx] += batch_idx * sequence_length
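And a usage sketch for the negative sampler (hypothetical shapes; assumes the function returns the sampled index array, as in the HF reference implementation):

# 10 negatives per position, flat indices into the batched feature matrix
neg_idx = _sample_negative_indices((2, 50), num_negatives=10)
print(neg_idx.shape)  # (2, 50, 10)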
@@ -243,8 +244,7 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Layer):
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias_attr=config.conv_bias,
)
bias_attr=config.conv_bias, )
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
@@ -264,8 +264,7 @@ class Wav2Vec2LayerNormConvLayer(nn.Layer):
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias_attr=config.conv_bias,
)
bias_attr=config.conv_bias, )
self.layer_norm = nn.LayerNorm(self.out_conv_dim)
self.activation = ACT2FN[config.feat_extract_activation]
@@ -290,11 +289,11 @@ class Wav2Vec2GroupNormConvLayer(nn.Layer):
self.out_conv_dim,
kernel_size=config.conv_kernel[layer_id],
stride=config.conv_stride[layer_id],
bias_attr=config.conv_bias,
)
bias_attr=config.conv_bias, )
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim)
self.layer_norm = nn.GroupNorm(
num_groups=self.out_conv_dim, num_channels=self.out_conv_dim)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
@@ -311,8 +310,7 @@ class Wav2Vec2PositionalConvEmbedding(nn.Layer):
config.hidden_size,
kernel_size=config.num_conv_pos_embeddings,
padding=config.num_conv_pos_embeddings // 2,
groups=config.num_conv_pos_embedding_groups,
)
groups=config.num_conv_pos_embedding_groups, )
self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
@@ -337,7 +335,7 @@ class Wav2Vec2SamePadLayer(nn.Layer):
def forward(self, hidden_states):
if self.num_pad_remove > 0:
hidden_states = hidden_states[:, :, : -self.num_pad_remove]
hidden_states = hidden_states[:, :, :-self.num_pad_remove]
return hidden_states
@@ -349,11 +347,13 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
if config.feat_extract_norm == "group":
conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1)
for i in range(config.num_feat_extract_layers - 1)
]
elif config.feat_extract_norm == "layer":
conv_layers = [
Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
Wav2Vec2LayerNormConvLayer(config, layer_id=i)
for i in range(config.num_feat_extract_layers)
]
else:
raise ValueError(
@@ -373,10 +373,12 @@ class Wav2Vec2FeatureEncoder(nn.Layer):
return hidden_states
class Wav2Vec2FeatureProjection(nn.Layer):
def __init__(self, config):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], epsilon=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(
config.conv_dim[-1], epsilon=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_proj_dropout)
@@ -393,13 +395,12 @@ class Wav2Vec2Attention(nn.Layer):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
self,
embed_dim: int,
num_heads: int,
dropout: float=0.0,
is_decoder: bool=False,
bias: bool=True, ):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -409,8 +410,7 @@ class Wav2Vec2Attention(nn.Layer):
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
@@ -420,17 +420,18 @@ class Wav2Vec2Attention(nn.Layer):
self.out_proj = nn.Linear(embed_dim, embed_dim, bias_attr=bias)
def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int):
return paddle.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)).transpose([0, 2, 1, 3])
return paddle.reshape(tensor, (bsz, seq_len, self.num_heads,
self.head_dim)).transpose([0, 2, 1, 3])
def forward(
self,
hidden_states: paddle.Tensor,
key_value_states: Optional[paddle.Tensor] = None,
past_key_value: Optional[Tuple[paddle.Tensor]] = None,
attention_mask: Optional[paddle.Tensor] = None,
layer_head_mask: Optional[paddle.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
self,
hidden_states: paddle.Tensor,
key_value_states: Optional[paddle.Tensor]=None,
past_key_value: Optional[Tuple[paddle.Tensor]]=None,
attention_mask: Optional[paddle.Tensor]=None,
layer_head_mask: Optional[paddle.Tensor]=None,
output_attentions: bool=False, ) -> Tuple[paddle.Tensor, Optional[
paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
@@ -455,7 +456,8 @@ class Wav2Vec2Attention(nn.Layer):
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = paddle.concat([past_key_value[0], key_states], axis=2)
value_states = paddle.concat([past_key_value[1], value_states], axis=2)
value_states = paddle.concat(
[past_key_value[1], value_states], axis=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -472,60 +474,68 @@ class Wav2Vec2Attention(nn.Layer):
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).reshape(proj_shape)
query_states = self._shape(query_states, tgt_len,
bsz).reshape(proj_shape)
key_states = key_states.reshape(proj_shape)
value_states = value_states.reshape(proj_shape)
src_len = key_states.shape[1]
attn_weights = paddle.bmm(query_states, key_states.transpose([0, 2, 1]))
if attn_weights.shape != [bsz * self.num_heads, tgt_len, src_len]:
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.shape}"
)
f" {attn_weights.shape}")
if attention_mask is not None:
if attention_mask.shape != [bsz, 1, tgt_len, src_len]:
raise ValueError(
f"Attention mask should be of size {[bsz, 1, tgt_len, src_len]}, but is {attention_mask.shape}"
)
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len,
src_len) + attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len,
src_len)
attn_weights = nn.functional.softmax(attn_weights, axis=- 1)
attn_weights = nn.functional.softmax(attn_weights, axis=-1)
if layer_head_mask is not None:
if layer_head_mask.shape != [self.num_heads,]:
if layer_head_mask.shape != [
self.num_heads,
]:
raise ValueError(
f"Head mask for a single layer should be of size {[self.num_heads,]}, but is"
f" {layer_head_mask.shape}"
)
attn_weights = layer_head_mask.reshape((1, -1, 1, 1)) * attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len))
attn_weights = attn_weights.reshape((bsz * self.num_heads, tgt_len, src_len))
f" {layer_head_mask.shape}")
attn_weights = layer_head_mask.reshape(
(1, -1, 1, 1)) * attn_weights.reshape(
(bsz, self.num_heads, tgt_len, src_len))
attn_weights = attn_weights.reshape(
(bsz * self.num_heads, tgt_len, src_len))
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.reshape((bsz, self.num_heads, tgt_len, src_len))
attn_weights = attn_weights_reshaped.reshape((bsz * self.num_heads, tgt_len, src_len))
attn_weights_reshaped = attn_weights.reshape(
(bsz, self.num_heads, tgt_len, src_len))
attn_weights = attn_weights_reshaped.reshape(
(bsz * self.num_heads, tgt_len, src_len))
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_probs, value_states)
if attn_output.shape != [bsz * self.num_heads, tgt_len, self.head_dim]:
raise ValueError(
f"`attn_output` should be of size {[bsz, self.num_heads, tgt_len, self.head_dim]}, but is"
f" {attn_output.shape}"
)
f" {attn_output.shape}")
attn_output = attn_output.reshape((bsz, self.num_heads, tgt_len, self.head_dim))
attn_output = attn_output.reshape(
(bsz, self.num_heads, tgt_len, self.head_dim))
attn_output = attn_output.transpose([0, 2, 1, 3])
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
@@ -542,13 +552,15 @@ class Wav2Vec2FeedForward(nn.Layer):
super().__init__()
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_dense = nn.Linear(config.hidden_size,
config.intermediate_size)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.output_dense = nn.Linear(config.intermediate_size,
config.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout)
def forward(self, hidden_states):
@@ -568,18 +580,23 @@ class Wav2Vec2EncoderLayer(nn.Layer):
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
is_decoder=False, )
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
def forward(self,
hidden_states,
attention_mask=None,
output_attentions=False):
attn_residual = hidden_states
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
@@ -587,10 +604,10 @@ class Wav2Vec2EncoderLayer(nn.Layer):
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
outputs = (hidden_states, )
if output_attentions:
outputs += (attn_weights,)
outputs += (attn_weights, )
return outputs
@@ -602,27 +619,33 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Layer):
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
)
is_decoder=False, )
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.final_layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
def forward(self,
hidden_states,
attention_mask=None,
output_attentions=False):
attn_residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
hidden_states, attn_weights, _ = self.attention(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = self.dropout(hidden_states)
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
hidden_states = hidden_states + self.feed_forward(
self.final_layer_norm(hidden_states))
outputs = (hidden_states,)
outputs = (hidden_states, )
if output_attentions:
outputs += (attn_weights,)
outputs += (attn_weights, )
return outputs
@@ -632,33 +655,38 @@ class Wav2Vec2Encoder(nn.Layer):
super().__init__()
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.LayerList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.LayerList([
Wav2Vec2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True, ):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(
1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = 1.0 - attention_mask[:, None, None, :].to(
dtype=hidden_states.dtype)
attention_mask = attention_mask * np.iinfo(np.float32).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
attention_mask = attention_mask.expand(attention_mask.shape[0], 1,
attention_mask.shape[-1],
attention_mask.shape[-1])
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -669,13 +697,14 @@ class Wav2Vec2Encoder(nn.Layer):
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states, )
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer:# or deepspeed_zero3_is_enabled:
skip_the_layer = True if self.training and (
dropout_probability < self.config.layerdrop) else False
if not skip_the_layer: # or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
@@ -686,26 +715,30 @@ class Wav2Vec2Encoder(nn.Layer):
return custom_forward
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1], )
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states, )
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v
for v in
[hidden_states, all_hidden_states, all_self_attentions]
if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
attentions=all_self_attentions, )
class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
@@ -713,35 +746,39 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
super().__init__()
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
self.layer_norm = nn.LayerNorm(
config.hidden_size, epsilon=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.LayerList(
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.layers = nn.LayerList([
Wav2Vec2EncoderLayerStableLayerNorm(config)
for _ in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
self,
hidden_states,
attention_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True, ):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat_interleave(hidden_states.shape[2], axis=2)
expand_attention_mask = attention_mask.unsqueeze(
-1).repeat_interleave(
hidden_states.shape[2], axis=2)
hidden_states[~expand_attention_mask] = 0
# extend attention_mask
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = 1.0 - attention_mask[:, None, None, :].to(
dtype=hidden_states.dtype)
attention_mask = attention_mask * np.iinfo(np.float32).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
attention_mask = attention_mask.expand(attention_mask.shape[0], 1,
attention_mask.shape[-1],
attention_mask.shape[-1])
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -749,13 +786,14 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
for layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states, )
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer:# or deepspeed_zero3_is_enabled:
skip_the_layer = True if self.training and (
dropout_probability < self.config.layerdrop) else False
if not skip_the_layer: # or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
@@ -767,28 +805,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Layer):
return custom_forward
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions)
hidden_states = layer_outputs[0]
if skip_the_layer:
layer_outputs = (None, None)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
all_self_attentions = all_self_attentions + (layer_outputs[1], )
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
all_hidden_states = all_hidden_states + (hidden_states, )
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return tuple(
v
for v in
[hidden_states, all_hidden_states, all_self_attentions]
if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
attentions=all_self_attentions, )
class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
@@ -810,9 +852,13 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
# storage for codebook variables (codewords)
self.codevectors = paddle.static.create_parameter(
shape=[1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups], dtype='float32'
)
self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
shape=[
1, self.num_groups * self.num_vars,
config.codevector_dim // self.num_groups
],
dtype='float32')
self.weight_proj = nn.Linear(config.conv_dim[-1],
self.num_groups * self.num_vars)
# can be decayed for training
self.temperature = 2
@@ -826,7 +872,8 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
else:
marginal_probs = probs.mean(axis=0)
perplexity = paddle.exp(-paddle.sum(marginal_probs * paddle.log(marginal_probs + 1e-7), dim=-1)).sum()
perplexity = paddle.exp(-paddle.sum(
marginal_probs * paddle.log(marginal_probs + 1e-7), axis=-1)).sum()
return perplexity
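The perplexity here is exp(entropy) of codebook usage, so a uniform distribution over N codewords yields perplexity N. Quick numeric check (numpy sketch):

import numpy as np

probs = np.full((4,), 0.25)  # uniform over 4 codewords
print(np.exp(-(probs * np.log(probs + 1e-7)).sum()))  # ~4.0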
def forward(self, hidden_states, mask_time_indices=None):
@@ -834,35 +881,45 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Layer):
# project to codevector dim
hidden_states = self.weight_proj(hidden_states)
hidden_states = hidden_states.reshape((batch_size * sequence_length * self.num_groups, -1))
hidden_states = hidden_states.reshape(
(batch_size * sequence_length * self.num_groups, -1))
if self.training:
# sample codevector probs via gumbel-softmax in a differentiable way
codevector_probs = nn.functional.gumbel_softmax(
hidden_states.float(), tau=self.temperature, hard=True
).type_as(hidden_states)
hidden_states.float(), tau=self.temperature,
hard=True).type_as(hidden_states)
# compute perplexity
codevector_soft_dist = paddle.softmax(
hidden_states.reshape((batch_size * sequence_length, self.num_groups, -1)).float(), axis=-1
)
perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
hidden_states.reshape((batch_size * sequence_length,
self.num_groups, -1)).float(),
axis=-1)
perplexity = self._compute_perplexity(codevector_soft_dist,
mask_time_indices)
else:
# take argmax in non-differentiable way
# compute hard codevector distribution (one hot)
codevector_idx = hidden_states.argmax(axis=-1)
codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
-1, codevector_idx.reshape((-1, 1)), 1.0
)
codevector_probs = codevector_probs.reshape((batch_size * sequence_length, self.num_groups, -1))
perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
codevector_probs = codevector_probs.reshape((batch_size * sequence_length, -1))
codevector_probs = hidden_states.new_zeros(
*hidden_states.shape).scatter_(-1,
codevector_idx.reshape((-1, 1)),
1.0)
codevector_probs = codevector_probs.reshape(
(batch_size * sequence_length, self.num_groups, -1))
perplexity = self._compute_perplexity(codevector_probs,
mask_time_indices)
codevector_probs = codevector_probs.reshape(
(batch_size * sequence_length, -1))
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
codevectors = codevectors_per_group.reshape((batch_size * sequence_length, self.num_groups, self.num_vars, -1))
codevectors = codevectors.sum(-2).reshape((batch_size, sequence_length, -1))
codevectors_per_group = codevector_probs.unsqueeze(
-1) * self.codevectors
codevectors = codevectors_per_group.reshape(
(batch_size * sequence_length, self.num_groups, self.num_vars, -1))
codevectors = codevectors.sum(-2).reshape(
(batch_size, sequence_length, -1))
return codevectors, perplexity
@@ -878,7 +935,9 @@ class Wav2Vec2Adapter(nn.Layer):
else:
self.proj = self.proj_layer_norm = None
self.layers = nn.LayerList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers))
self.layers = nn.LayerList(
Wav2Vec2AdapterLayer(config)
for _ in range(config.num_adapter_layers))
self.layerdrop = config.layerdrop
def forward(self, hidden_states):
@@ -906,8 +965,7 @@ class Wav2Vec2AdapterLayer(nn.Layer):
2 * config.output_hidden_size,
config.adapter_kernel_size,
stride=config.adapter_stride,
padding=1,
)
padding=1, )
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
@@ -925,9 +983,13 @@ class Wav2Vec2Model(nn.Layer):
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
# self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_())
# self.masked_spec_embed = nn.Parameter(paddle.Tensor(config.hidden_size).uniform_())
#self.masked_spec_embed = paddle.uniform([config.hidden_size])
self.masked_spec_embed = paddle.static.create_parameter(shape=[config.hidden_size], dtype='float32', default_initializer=paddle.nn.initializer.Uniform(low=0, high=1.0))
self.masked_spec_embed = paddle.static.create_parameter(
shape=[config.hidden_size],
dtype='float32',
default_initializer=paddle.nn.initializer.Uniform(
low=0, high=1.0))
if config.do_stable_layer_norm:
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
else:
@@ -946,11 +1008,10 @@ class Wav2Vec2Model(nn.Layer):
self.feature_extractor._freeze_parameters()
def _mask_hidden_states(
self,
hidden_states: paddle.Tensor,
mask_time_indices: Optional[paddle.Tensor] = None,
attention_mask: Optional[paddle.Tensor] = None,
):
self,
hidden_states: paddle.Tensor,
mask_time_indices: Optional[paddle.Tensor]=None,
attention_mask: Optional[paddle.Tensor]=None, ):
"""
Masks extracted features along time axis and/or along feature axis according to
[SpecAugment](https://arxiv.org/abs/1904.08779).
@@ -963,17 +1024,19 @@ class Wav2Vec2Model(nn.Layer):
batch_size, sequence_length, hidden_size = hidden_states.shape
if mask_time_indices is not None:
# apply SpecAugment along time axis with given mask_time_indices
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(
hidden_states.dtype)
elif self.config.mask_time_prob > 0 and self.training:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
mask_prob=self.config.mask_time_prob,
mask_length=self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=self.config.mask_time_min_masks,
)
mask_time_indices = paddle.to_tensor(mask_time_indices, dtype=paddle.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
min_masks=self.config.mask_time_min_masks, )
mask_time_indices = paddle.to_tensor(
mask_time_indices, dtype=paddle.bool)
hidden_states[mask_time_indices] = self.masked_spec_embed.to(
hidden_states.dtype)
if self.config.mask_feature_prob > 0 and self.training:
# generate indices & apply SpecAugment along feature axis
@@ -981,27 +1044,28 @@ class Wav2Vec2Model(nn.Layer):
(batch_size, hidden_size),
mask_prob=self.config.mask_feature_prob,
mask_length=self.config.mask_feature_length,
min_masks=self.config.mask_feature_min_masks,
)
mask_feature_indices = paddle.to_tensor(mask_feature_indices, dtype=paddle.bool)
mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
min_masks=self.config.mask_feature_min_masks, )
mask_feature_indices = paddle.to_tensor(
mask_feature_indices, dtype=paddle.bool)
mask_feature_indices = mask_feature_indices[:, None].expand(
-1, sequence_length, -1)
hidden_states[mask_feature_indices] = 0
return hidden_states
def forward(
self,
input_values: Optional[paddle.Tensor],
attention_mask: Optional[paddle.Tensor] = None,
mask_time_indices: Optional[paddle.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
self,
input_values: Optional[paddle.Tensor],
attention_mask: Optional[paddle.Tensor]=None,
mask_time_indices: Optional[paddle.Tensor]=None,
output_attentions: Optional[bool]=None,
output_hidden_states: Optional[bool]=None,
return_dict: Optional[bool]=None,
) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
extract_features = self.feature_extractor(input_values)
extract_features = extract_features.transpose([0, 2, 1])
@@ -1009,20 +1073,20 @@ class Wav2Vec2Model(nn.Layer):
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features)
extract_features.shape[1], attention_mask, add_adapter=False)
hidden_states, extract_features = self.feature_projection(
extract_features)
hidden_states = self._mask_hidden_states(
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
)
hidden_states,
mask_time_indices=mask_time_indices,
attention_mask=attention_mask)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
return_dict=return_dict, )
hidden_states = encoder_outputs[0]
@@ -1036,20 +1100,21 @@ class Wav2Vec2Model(nn.Layer):
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
attentions=encoder_outputs.attentions, )
def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
modules properly initialized (such as weight initialization).
"""
# self.init_weights()
# self._backward_compatibility_gradient_checkpointing()
# self.init_weights()
# self._backward_compatibility_gradient_checkpointing()
pass
class Wav2Vec2ConfigPure():
model_type = "wav2vec2"
def __init__(self, config):
self.output_attentions = False
self.output_hidden_states = False
@@ -1084,17 +1149,14 @@ class Wav2Vec2ConfigPure():
self.do_stable_layer_norm = config.do_stable_layer_norm
self.use_weighted_layer_sum = config.use_weighted_layer_sum
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
or (len(self.conv_kernel) != self.num_feat_extract_layers)
or (len(self.conv_dim) != self.num_feat_extract_layers)
):
if ((len(self.conv_stride) != self.num_feat_extract_layers) or
(len(self.conv_kernel) != self.num_feat_extract_layers) or
(len(self.conv_dim) != self.num_feat_extract_layers)):
raise ValueError(
"Configuration for convolutional layers is incorrect. It is required that `len(config.conv_dim)` =="
" `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) ="
f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`,"
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
f" `len(config.conv_kernel) = {len(self.conv_kernel)}`.")
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = config.apply_spec_augment

@@ -7,10 +7,8 @@ Authors
* Samuele Cornell 2020
* Sarthak Yadav 2022
"""
import paddle
import math
from packaging import version
import numpy as np
import paddle
def blackman_window(window_length, periodic=True):
@@ -90,15 +88,14 @@ def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
def convolve1d(
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1,
use_fft=False,
rotation_index=0,
):
waveform,
kernel,
padding=0,
pad_type="constant",
stride=1,
groups=1,
use_fft=False,
rotation_index=0, ):
"""Use paddle.nn.functional to perform 1d padding and conv.
Arguments
---------
@@ -150,8 +147,7 @@ def convolve1d(
# Padding can be a tuple (left_pad, right_pad) or an int
if isinstance(padding, tuple):
waveform = paddle.nn.functional.pad(
x=waveform, pad=padding, mode=pad_type, data_format='NCL'
)
x=waveform, pad=padding, mode=pad_type, data_format='NCL')
# This approach uses FFT, which is more efficient if the kernel is large
if use_fft:
@@ -165,9 +161,7 @@ def convolve1d(
# Perform rotation to ensure alignment
zeros = paddle.zeros(
[kernel.shape[0], kernel.shape[1], zero_length],
dtype=kernel.dtype
)
[kernel.shape[0], kernel.shape[1], zero_length], dtype=kernel.dtype)
after_index = kernel[..., rotation_index:]
before_index = kernel[..., :rotation_index]
kernel = paddle.concat((after_index, zeros, before_index), axis=-1)
@@ -185,12 +179,12 @@ def convolve1d(
weight=kernel,
stride=stride,
groups=groups,
padding=padding if not isinstance(padding, tuple) else 0,
)
padding=padding if not isinstance(padding, tuple) else 0, )
# Return time dimension to the second dimension.
return convolved.transpose([0, 2, 1])
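Usage sketch for convolve1d, assuming the SpeechBrain-style [batch, time, channels] layout implied by the final transpose, with the kernel laid out [1, width, 1] as notch_filter produces it:

import paddle

wav = paddle.randn([1, 16000, 1])              # [batch, time, channels]
kernel = paddle.full([1, 5, 1], 0.2)           # 5-tap moving average
out = convolve1d(wav, kernel, padding=(2, 2))  # same-length output
print(out.shape)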
def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
"""Returns a notch filter constructed from a high-pass and low-pass filter.
(from https://tomroelandts.com/articles/
@@ -224,7 +218,8 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
return paddle.sin(x) / x
# The zero is at the middle index
return paddle.concat([_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1 :])])
return paddle.concat(
[_sinc(x[:pad]), paddle.ones([1]), _sinc(x[pad + 1:])])
# Compute a low-pass filter with cutoff frequency notch_freq.
hlpf = sinc(3 * (notch_freq - notch_width) * inputs)
@@ -239,4 +234,3 @@ def notch_filter(notch_freq, filter_width=101, notch_width=0.05):
# Adding filters creates notch filter
return (hlpf + hhpf).view(1, -1, 1)
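Usage sketch, assuming notch_freq is expressed as a fraction of the Nyquist band (this module's convention):

kernel = notch_filter(0.25, filter_width=101, notch_width=0.05)
print(kernel.shape)  # (1, 101, 1), ready to pass to convolve1d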

@@ -1,11 +1,12 @@
import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import (
compute_amplitude,
convolve1d,
notch_filter)
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import compute_amplitude
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import convolve1d
from paddlespeech.s2t.models.wav2vec2.processing.signal_processing import notch_filter
class SpeedPerturb(nn.Layer):
"""Slightly speed up or slow down an audio signal.
@@ -36,8 +37,10 @@ class SpeedPerturb(nn.Layer):
"""
def __init__(
self, orig_freq, speeds=[90, 100, 110], perturb_prob=1.0,
):
self,
orig_freq,
speeds=[90, 100, 110],
perturb_prob=1.0, ):
super().__init__()
self.orig_freq = orig_freq
self.speeds = speeds
@@ -73,11 +76,12 @@ class SpeedPerturb(nn.Layer):
return waveform.clone()
# Perform a random perturbation
self.samp_index = paddle.randint(len(self.speeds), shape=(1,))[0]
self.samp_index = paddle.randint(len(self.speeds), shape=(1, ))[0]
perturbed_waveform = self.resamplers[self.samp_index](waveform)
return perturbed_waveform
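Usage sketch (illustrative; assumes a [batch, time] waveform input):

import paddle

perturb = SpeedPerturb(orig_freq=16000, speeds=[90, 100, 110])
wav = paddle.randn([4, 16000])  # four 1-second clips
out = perturb(wav)              # batch randomly sped up or slowed down ~10%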
class Resample(nn.Layer):
"""This class resamples an audio signal using sinc-based interpolation.
@@ -94,9 +98,12 @@ class Resample(nn.Layer):
Controls the sharpness of the filter, larger numbers result in a
sharper filter, but they are less efficient. Values from 4 to 10 are allowed.
"""
def __init__(
self, orig_freq=16000, new_freq=16000, lowpass_filter_width=6,
):
self,
orig_freq=16000,
new_freq=16000,
lowpass_filter_width=6, ):
super().__init__()
self.orig_freq = orig_freq
self.new_freq = new_freq
@ -193,8 +200,7 @@ class Resample(nn.Layer):
window_size = self.weights.shape[1]
tot_output_samp = self._output_samples(wave_len)
resampled_waveform = paddle.zeros(
(batch_size, num_channels, tot_output_samp)
)
(batch_size, num_channels, tot_output_samp))
# self.weights = self.weights.to(waveforms.device)
# Check weights are on correct device
@ -222,28 +228,25 @@ class Resample(nn.Layer):
right_padding = max(0, end_index + 1 - current_wave_len)
left_padding = max(0, -first_index)
wave_to_conv = paddle.nn.functional.pad(
wave_to_conv, (left_padding, right_padding), data_format='NCL'
)
wave_to_conv, (left_padding, right_padding), data_format='NCL')
conv_wave = paddle.nn.functional.conv1d(
x=wave_to_conv,
weight=self.weights[i].repeat(num_channels, 1, 1),
stride=self.conv_stride,
groups=num_channels,
)
groups=num_channels, )
# we want conv_wave[:, i] to be at
# output[:, i + n*conv_transpose_stride]
dilated_conv_wave = paddle.nn.functional.conv1d_transpose(
conv_wave, eye, stride=self.conv_transpose_stride
)
conv_wave, eye, stride=self.conv_transpose_stride)
# pad dilated_conv_wave so it reaches the output length if needed.
left_padding = i
previous_padding = left_padding + dilated_conv_wave.shape[-1]
right_padding = max(0, tot_output_samp - previous_padding)
dilated_conv_wave = paddle.nn.functional.pad(
dilated_conv_wave, (left_padding, right_padding), data_format='NCL'
)
dilated_conv_wave, (left_padding, right_padding),
data_format='NCL')
dilated_conv_wave = dilated_conv_wave[..., :tot_output_samp]
resampled_waveform += dilated_conv_wave
@ -326,9 +329,7 @@ class Resample(nn.Layer):
window_width = self.lowpass_filter_width / (2.0 * lowpass_cutoff)
assert lowpass_cutoff < min(self.orig_freq, self.new_freq) / 2
output_t = paddle.arange(
start=0.0, end=self.output_samples
)
output_t = paddle.arange(start=0.0, end=self.output_samples)
output_t /= self.new_freq
min_t = output_t - window_width
max_t = output_t + window_width
@ -346,23 +347,16 @@ class Resample(nn.Layer):
inside_window_indices = delta_t.abs() < (window_width)
# raised-cosine (Hanning) window with width `window_width`
weights[inside_window_indices] = 0.5 * (
1
+ paddle.cos(
2
* math.pi
* lowpass_cutoff
/ self.lowpass_filter_width
* delta_t[inside_window_indices]
)
)
weights[inside_window_indices] = 0.5 * (1 + paddle.cos(
2 * math.pi * lowpass_cutoff / self.lowpass_filter_width *
delta_t[inside_window_indices]))
t_eq_zero_indices = delta_t == 0.0
t_not_eq_zero_indices = ~t_eq_zero_indices
# sinc filter function
weights[t_not_eq_zero_indices] *= paddle.sin(
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]
) / (math.pi * delta_t[t_not_eq_zero_indices])
2 * math.pi * lowpass_cutoff * delta_t[t_not_eq_zero_indices]) / (
math.pi * delta_t[t_not_eq_zero_indices])
# limit of the function at t = 0
weights[t_eq_zero_indices] *= 2 * lowpass_cutoff
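Reconstructed as one formula (notation mine): with $f_c$ = lowpass_cutoff and $W$ = lowpass_filter_width, the weight computed above for each time offset $\Delta t$ is a Hann-windowed sinc low-pass,

$$ w(\Delta t) = \frac{1}{2}\Bigl(1 + \cos\frac{2\pi f_c\,\Delta t}{W}\Bigr) \cdot \begin{cases} \dfrac{\sin(2\pi f_c\,\Delta t)}{\pi\,\Delta t}, & \Delta t \neq 0 \\[4pt] 2 f_c, & \Delta t = 0 \end{cases} \qquad \text{for } |\Delta t| < \frac{W}{2 f_c}, $$

and $w = 0$ elsewhere; the $2 f_c$ branch is exactly the $\Delta t \to 0$ limit of the sinc term, as the final comment notes.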
@ -405,14 +399,13 @@ class DropFreq(nn.Layer):
"""
def __init__(
self,
drop_freq_low=1e-14,
drop_freq_high=1,
drop_count_low=1,
drop_count_high=2,
drop_width=0.05,
drop_prob=1,
):
self,
drop_freq_low=1e-14,
drop_freq_high=1,
drop_count_low=1,
drop_count_high=2,
drop_width=0.05,
drop_prob=1, ):
super().__init__()
self.drop_freq_low = drop_freq_low
self.drop_freq_high = drop_freq_high
@ -443,14 +436,14 @@ class DropFreq(nn.Layer):
# Pick number of frequencies to drop
drop_count = paddle.randint(
low=self.drop_count_low, high=self.drop_count_high + 1, shape=(1,),
)
low=self.drop_count_low,
high=self.drop_count_high + 1,
shape=(1, ), )
# Pick a frequency to drop
drop_range = self.drop_freq_high - self.drop_freq_low
drop_frequency = (
paddle.rand(drop_count) * drop_range + self.drop_freq_low
)
paddle.rand(drop_count) * drop_range + self.drop_freq_low)
# Filter parameters
filter_length = 101
pad = filter_length // 2
@ -461,8 +454,9 @@ class DropFreq(nn.Layer):
# Subtract each frequency
for frequency in drop_frequency:
notch_kernel = notch_filter(
frequency, filter_length, self.drop_width,
)
frequency,
filter_length,
self.drop_width, )
drop_filter = convolve1d(drop_filter, notch_kernel, pad)
# Apply filter
@ -471,6 +465,7 @@ class DropFreq(nn.Layer):
# Remove channels dimension if added
return dropped_waveform.squeeze(-1)
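A usage sketch with the default drop counts (1-2 notches per call); the 2-D input layout is an assumption, though the channel handling visible above suggests the channel dimension is added and removed internally:

    import paddle
    drop_freq = DropFreq()
    wavs = paddle.randn([8, 16000])  # assumed (batch, time)
    filtered = drop_freq(wavs)       # same shape, random narrow bands removed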
class DropChunk(nn.Layer):
"""This class drops portions of the input signal.
Using `DropChunk` as an augmentation strategy helps a model learn to rely
@ -515,16 +510,15 @@ class DropChunk(nn.Layer):
"""
def __init__(
self,
drop_length_low=100,
drop_length_high=1000,
drop_count_low=1,
drop_count_high=10,
drop_start=0,
drop_end=None,
drop_prob=1,
noise_factor=0.0,
):
self,
drop_length_low=100,
drop_length_high=1000,
drop_count_low=1,
drop_count_high=10,
drop_start=0,
drop_end=None,
drop_prob=1,
noise_factor=0.0, ):
super().__init__()
self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high
@ -580,8 +574,7 @@ class DropChunk(nn.Layer):
drop_times = paddle.randint(
low=self.drop_count_low,
high=self.drop_count_high + 1,
shape=(batch_size,),
)
shape=(batch_size, ), )
# Iterate batch to set mask
for i in range(batch_size):
@ -592,8 +585,7 @@ class DropChunk(nn.Layer):
length = paddle.randint(
low=self.drop_length_low,
high=self.drop_length_high + 1,
shape=(drop_times[i],),
)
shape=(drop_times[i], ), )
# Compute range of starting locations
start_min = self.drop_start
@ -608,15 +600,16 @@ class DropChunk(nn.Layer):
# Pick starting locations
start = paddle.randint(
low=start_min, high=start_max + 1, shape=(drop_times[i],),
)
low=start_min,
high=start_max + 1,
shape=(drop_times[i], ), )
end = start + length
# Update waveform
if not self.noise_factor:
for j in range(drop_times[i]):
dropped_waveform[i, start[j] : end[j]] = 0.0
dropped_waveform[i, start[j]:end[j]] = 0.0
else:
# Uniform distribution of -2 to +2 * avg amplitude should
# preserve the average for normalization
@ -625,7 +618,7 @@ class DropChunk(nn.Layer):
# zero-center the noise distribution
noise_vec = paddle.rand([length[j]])
noise_vec = 2 * noise_max * noise_vec - noise_max
dropped_waveform[i, start[j] : end[j]] = noise_vec
dropped_waveform[i, start[j]:end[j]] = noise_vec
return dropped_waveform
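A hedged usage sketch; the relative-lengths argument follows the upstream API and is an assumption here, since the forward signature is not part of this hunk:

    import paddle
    drop_chunk = DropChunk(drop_length_low=1000, drop_length_high=2000)
    wavs = paddle.randn([8, 16000])
    lengths = paddle.ones([8])           # assumed relative lengths, 1.0 = full
    dropped = drop_chunk(wavs, lengths)  # zeroes out 1-10 random chunks per wav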
@ -679,37 +672,33 @@ class TimeDomainSpecAugment(nn.Layer):
"""
def __init__(
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0,
):
self,
perturb_prob=1.0,
drop_freq_prob=1.0,
drop_chunk_prob=1.0,
speeds=[95, 100, 105],
sample_rate=16000,
drop_freq_count_low=0,
drop_freq_count_high=3,
drop_chunk_count_low=0,
drop_chunk_count_high=5,
drop_chunk_length_low=1000,
drop_chunk_length_high=2000,
drop_chunk_noise_factor=0, ):
super().__init__()
self.speed_perturb = SpeedPerturb(
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds
)
perturb_prob=perturb_prob, orig_freq=sample_rate, speeds=speeds)
self.drop_freq = DropFreq(
drop_prob=drop_freq_prob,
drop_count_low=drop_freq_count_low,
drop_count_high=drop_freq_count_high,
)
drop_count_high=drop_freq_count_high, )
self.drop_chunk = DropChunk(
drop_prob=drop_chunk_prob,
drop_count_low=drop_chunk_count_low,
drop_count_high=drop_chunk_count_high,
drop_length_low=drop_chunk_length_low,
drop_length_high=drop_chunk_length_high,
noise_factor=drop_chunk_noise_factor,
)
noise_factor=drop_chunk_noise_factor, )
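Putting the three stages together, a sketch matching the forward(waveforms, lengths) signature that follows; treating lengths as relative utterance lengths is an assumption carried over from the submodules:

    import paddle
    augment = TimeDomainSpecAugment(sample_rate=16000, speeds=[95, 100, 105])
    wavs = paddle.randn([8, 16000])
    lengths = paddle.ones([8])
    augmented = augment(wavs, lengths)  # speed perturb -> drop freq -> drop chunk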
def forward(self, waveforms, lengths):
"""Returns the distorted waveforms.

@ -1,25 +1,19 @@
import numpy as np
import os
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.modules.mask import make_pad_mask
from paddlespeech.s2t.utils.utility import log_add
from collections import defaultdict
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from yacs.config import CfgNode
from paddlespeech.s2t.utils.utility import log_add
class Wav2vec2ASR(nn.Layer):
def __init__(self, config: dict):
@ -36,8 +30,16 @@ class Wav2vec2ASR(nn.Layer):
for parm in wav2vec2.parameters():
parm.trainable = False
self.wav2vec2 = wav2vec2
self.enc = VanillaNN(input_shape=[None,None,wav2vec2_config.hidden_size], activation=nn.LeakyReLU, dnn_blocks=config.dnn_blocks, dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim, enc_n_units=config.dnn_neurons, blank_id=config.blank_id, dropout_rate=config.ctc_dropout_rate, reduction=True)
self.enc = VanillaNN(
input_shape=[None, None, wav2vec2_config.hidden_size],
activation=nn.LeakyReLU,
dnn_blocks=config.dnn_blocks,
dnn_neurons=config.dnn_neurons)
self.ctc = CTC(odim=config.output_dim,
enc_n_units=config.dnn_neurons,
blank_id=config.blank_id,
dropout_rate=config.ctc_dropout_rate,
reduction=True)
def forward(self, wav, wavs_lens_rate, target, target_lens_rate):
if self.normalize_wav:
@ -51,7 +53,8 @@ class Wav2vec2ASR(nn.Layer):
x = self.enc(feats)
x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate * target.shape[1]).round().astype(paddle.int64)
target_lens = (target_lens_rate *
target.shape[1]).round().astype(paddle.int64)
ctc_loss = self.ctc(x, x_lens, target, target_lens)
return ctc_loss
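For orientation, a hedged sketch of one training step against this forward signature; constructing the CfgNode is elided, and the relative-length convention is read off the two .round() conversions above:

    import paddle
    model = Wav2vec2ASR(config)              # config: CfgNode (not shown here)
    wav = paddle.randn([2, 32000])           # raw audio batch
    wav_lens = paddle.to_tensor([1.0, 0.8])  # fractions of the padded length
    tokens = paddle.to_tensor([[5, 9, 2, 0], [7, 3, 0, 0]])
    token_lens = paddle.to_tensor([1.0, 0.5])
    loss = model(wav, wav_lens, tokens, token_lens)  # scalar CTC loss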
@ -63,7 +66,8 @@ class Wav2vec2ASR(nn.Layer):
decoding_method: str,
beam_size: int):
batch_size = feats.shape[0]
if decoding_method is 'ctc_prefix_beam_search' and batch_size > 1:
if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
logger.error(
f'decoding mode {decoding_method} must run with batch_size == 1'
)
@ -79,13 +83,12 @@ class Wav2vec2ASR(nn.Layer):
# with other batch decoding mode
elif decoding_method == 'ctc_prefix_beam_search':
assert feats.shape[0] == 1
hyp = self.ctc_prefix_beam_search(
feats,
beam_size)
hyp = self.ctc_prefix_beam_search(feats, beam_size)
res = [text_feature.defeaturize(hyp)]
res_tokenids = [hyp]
else:
raise ValueError(f"wav2vec2 does not support decoding method: {decoding_method}")
raise ValueError(
f"wav2vec2 does not support decoding method: {decoding_method}")
return res, res_tokenids
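And the matching inference call; beam_size and the input tensor are illustrative, and text_feature is a TextFeaturizer instance as used by the infer script elsewhere in this PR:

    xs = paddle.randn([1, 32000])  # batch_size must be 1 for prefix beam search
    hyps, token_ids = model.decode(
        xs,
        text_feature=text_feature,
        decoding_method='ctc_prefix_beam_search',
        beam_size=10)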
@ -94,8 +97,7 @@ class Wav2vec2ASR(nn.Layer):
model = cls(config)
return model
def ctc_greedy_search(
self, wav) -> List[List[int]]:
def ctc_greedy_search(self, wav) -> List[List[int]]:
""" Apply CTC greedy search
Args:
wav (paddle.Tensor): (batch, max_len)
@ -104,7 +106,7 @@ class Wav2vec2ASR(nn.Layer):
List[List[int]]: best path result
"""
batch_size = wav.shape[0]
wav = wav[:,:,0]
wav = wav[:, :, 0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
# Extract wav2vec output
@ -124,7 +126,10 @@ class Wav2vec2ASR(nn.Layer):
return hyps
def _ctc_prefix_beam_search(
self, wav, beam_size, blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
self,
wav,
beam_size,
blank_id: int=0, ) -> Tuple[List[Tuple[int, float]], paddle.Tensor]:
""" CTC prefix beam search inner implementation
Args:
wav (paddle.Tensor): (batch, max_len, feat_dim)
@ -142,7 +147,7 @@ class Wav2vec2ASR(nn.Layer):
paddle.Tensor: encoder output, (1, max_len, encoder_dim),
it will be used for rescoring in attention rescoring mode
"""
wav = wav[:,:,0]
wav = wav[:, :, 0]
if self.normalize_wav:
wav = F.layer_norm(wav, wav.shape[1:])
@ -219,29 +224,5 @@ class Wav2vec2ASR(nn.Layer):
Returns:
List[int]: CTC prefix beam search nbest results
"""
hyps = self._ctc_prefix_beam_search(
wav, beam_size)
hyps = self._ctc_prefix_beam_search(wav, beam_size)
return hyps[0][0]
# @jit.to_static
# def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
# """ Export interface for c++ call, apply linear transform and log
# softmax before ctc
# Args:
# xs (paddle.Tensor): encoder output, (B, T, D)
# Returns:
# paddle.Tensor: activation before ctc
# """
# return self.ctc.log_softmax(xs)
# def _get_data(self):
# data_dir = "data"
# wavs = np.load(os.path.join(data_dir, "wavs.npy"))
# wavs_lens = np.load(os.path.join(data_dir, "wavs_lens.npy"))
# tokens = np.load(os.path.join(data_dir, "tokens.npy"))
# tokens_lens = np.load(os.path.join(data_dir, "tokens_lens.npy"))
# batch = (paddle.to_tensor(wavs), paddle.to_tensor(wavs_lens, dtype='float32'),
# paddle.to_tensor(tokens, dtype='int32'), paddle.to_tensor(tokens_lens, dtype='float32'))
# return batch
