# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Tuple

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2ConfigPure
from paddlespeech.s2t.models.wav2vec2.modules.modeling_wav2vec2 import Wav2Vec2Model
from paddlespeech.s2t.models.wav2vec2.modules.VanillaNN import VanillaNN
from paddlespeech.s2t.models.wav2vec2.processing.speech_augmentation import SpecAugment
from paddlespeech.s2t.modules.ctc import CTCDecoderBase as CTC
from paddlespeech.s2t.modules.initializer import DefaultInitializerContext
from paddlespeech.s2t.utils.ctc_utils import remove_duplicates_and_blank
from paddlespeech.s2t.utils.log import Log
from paddlespeech.s2t.utils.utility import log_add

logger = Log(__name__).getlog()


class Wav2vec2ASR(nn.Layer):
    """wav2vec 2.0 encoder followed by a VanillaNN and a CTC output layer.

    Training goes through ``forward`` (returns the CTC loss); inference goes
    through ``decode`` (CTC greedy search or CTC prefix beam search).
    """

    def __init__(self, config: dict):
        super().__init__()
        init_type = config.get("init_type", None)
        # All sub-layers are built inside the initializer context so that
        # ``init_type`` applies to their parameters.
        with DefaultInitializerContext(init_type):
            self.config = config
            wav2vec2_config = Wav2Vec2ConfigPure(config)
            wav2vec2 = Wav2Vec2Model(wav2vec2_config)
            self.normalize_wav = config.normalize_wav
            self.output_norm = config.output_norm
            if hasattr(config, 'spec_augment'):
                self.spec_augment = SpecAugment(**config.spec_augment)

            if config.freeze_wav2vec2:
                # Freeze the pretrained encoder: eval mode and no updates.
                wav2vec2.eval()
                for parm in wav2vec2.parameters():
                    parm.trainable = False
            self.wav2vec2 = wav2vec2
            self.enc = VanillaNN(**config.enc)
            self.ctc = CTC(**config.ctc,
                           odim=config.output_dim,
                           batch_average=False,
                           reduction='mean')

    def _extract_features(self, wav: paddle.Tensor) -> paddle.Tensor:
        """Run wav2vec2 on raw audio, with optional input/output layer norm.

        Normalization is applied per utterance (over every axis except the
        batch axis), so an utterance's features do not depend on the other
        utterances in the batch; training and decoding share this path.

        Args:
            wav (paddle.Tensor): batched waveforms; presumably
                (batch, samples) -- TODO confirm against the data loader.

        Returns:
            paddle.Tensor: first output of ``self.wav2vec2`` (optionally
                layer-normalized).
        """
        if self.normalize_wav:
            wav = F.layer_norm(wav, wav.shape[1:])
        out = self.wav2vec2(wav)[0]
        if self.output_norm:
            out = F.layer_norm(out, out.shape[1:])
        return out

    def forward(self, wav, wavs_lens_rate, target, target_lens):
        """Compute the CTC training loss.

        Args:
            wav (paddle.Tensor): batched waveforms fed straight to wav2vec2.
            wavs_lens_rate (paddle.Tensor): per-utterance valid-length ratio;
                multiplied by the encoder time length to get frame lengths.
            target (paddle.Tensor): padded target token ids.
            target_lens (paddle.Tensor): target lengths.

        Returns:
            paddle.Tensor: CTC loss returned by ``self.ctc``.
        """
        out = self._extract_features(wav)

        # SpecAugment only during training and only when configured.
        if self.training and hasattr(self.config, 'spec_augment'):
            feats = self.spec_augment(out)
        else:
            feats = out

        x = self.enc(feats)
        # Convert relative lengths into absolute encoder-frame counts.
        x_lens = (wavs_lens_rate * x.shape[1]).round().astype(paddle.int64)

        ctc_loss = self.ctc(x, x_lens, target, target_lens)
        return ctc_loss

    @staticmethod
    def _strip_special_tokens(text_feature, token_ids) -> Tuple[str, List[int]]:
        """Turn token ids into text, skipping [CLS] and stopping at [SEP]/[PAD].

        Args:
            text_feature: object exposing ``convert_ids_to_tokens`` and a
                ``vocab`` token->id mapping (tokenizer-like).
            token_ids (List[int]): decoded token ids.

        Returns:
            Tuple[str, List[int]]: the joined tokens and the ids of the kept
                tokens (looked up again through ``text_feature.vocab``).
        """
        kept_tokens = []
        kept_ids = []
        for token in text_feature.convert_ids_to_tokens(token_ids):
            if token == "[CLS]":
                continue
            if token == "[SEP]" or token == "[PAD]":
                break
            kept_tokens.append(token)
            kept_ids.append(text_feature.vocab[token])
        return ''.join(kept_tokens), kept_ids

    @paddle.no_grad()
    def decode(self,
               feats: paddle.Tensor,
               text_feature: Dict[str, int],
               decoding_method: str,
               beam_size: int,
               tokenizer: str=None,
               sb_pipeline=False):
        """Decode a batch of waveforms into text.

        Args:
            feats (paddle.Tensor): input audio. The search methods read
                ``feats[:, :, 0]``; with ``sb_pipeline=True`` a trailing axis
                is added first.
            text_feature: NOTE(review): annotated ``Dict[str, int]`` but used
                as a featurizer/tokenizer object (``defeaturize``,
                ``convert_ids_to_tokens``, ``vocab``).
            decoding_method (str): 'ctc_greedy_search' or
                'ctc_prefix_beam_search'.
            beam_size (int): beam width for prefix beam search.
            tokenizer (str, optional): only checked against None; when set,
                ids are mapped through ``text_feature.convert_ids_to_tokens``.
            sb_pipeline (bool): SpeechBrain-style pipeline flag.

        Returns:
            Tuple[list, list]: decoded results and their token ids.

        Raises:
            ValueError: for an unsupported ``decoding_method``.
        """
        batch_size = feats.shape[0]
        if decoding_method == 'ctc_prefix_beam_search' and batch_size > 1:
            logger.error(
                f"decoding mode {decoding_method} must be running with batch_size == 1"
            )
            logger.error(f"current batch_size is {batch_size}")

        # Raw ids are defeaturized directly only in the plain (no tokenizer,
        # no SpeechBrain pipeline) setup.
        use_defeaturize = tokenizer is None and sb_pipeline is False
        if sb_pipeline is True:
            # The search methods index wav[:, :, 0]; add the trailing axis.
            feats = feats.unsqueeze(-1)

        if decoding_method == 'ctc_greedy_search':
            hyps = self.ctc_greedy_search(feats)
        elif decoding_method == 'ctc_prefix_beam_search':
            assert feats.shape[0] == 1
            # Prefix beam search yields one hypothesis; wrap it as a batch of
            # one so both methods share the post-processing below.
            hyps = [self.ctc_prefix_beam_search(feats, beam_size)]
        else:
            raise ValueError(
                f"wav2vec2 not support decoding method: {decoding_method}")

        if use_defeaturize:
            res = [text_feature.defeaturize(hyp) for hyp in hyps]
            res_tokenids = list(hyps)
        else:
            res = []
            res_tokenids = []
            for sequence in hyps:
                text, token_ids = self._strip_special_tokens(text_feature,
                                                             sequence)
                res.append(text)
                res_tokenids.append(token_ids)
        return res, res_tokenids

    @classmethod
    def from_config(cls, config):
        """Build the model from a config object."""
        model = cls(config)
        return model

    def ctc_greedy_search(self, wav) -> List[List[int]]:
        """Apply CTC greedy (best-path) search.

        Args:
            wav (paddle.Tensor): rank-3 audio tensor; only ``wav[:, :, 0]``
                is used (presumably (batch, samples, 1) -- TODO confirm).

        Returns:
            List[List[int]]: one best-path id sequence per utterance, passed
                through ``remove_duplicates_and_blank``.
        """
        batch_size = wav.shape[0]
        x = self.enc(self._extract_features(wav[:, :, 0]))
        maxlen = x.shape[1]
        ctc_probs = self.ctc.log_softmax(x)  # (B, maxlen, vocab_size)
        _, topk_index = ctc_probs.topk(1, axis=2)  # (B, maxlen, 1)
        topk_index = topk_index.reshape([batch_size, maxlen])  # (B, maxlen)
        # One host transfer for the whole batch instead of one per row.
        hyps = topk_index.tolist()
        return [remove_duplicates_and_blank(hyp) for hyp in hyps]

    def _ctc_prefix_beam_search(
            self,
            wav,
            beam_size,
            blank_id: int=0, ) -> List[Tuple[Tuple[int, ...], float]]:
        """CTC prefix beam search inner implementation.

        Args:
            wav (paddle.Tensor): rank-3 audio tensor with batch size 1; only
                ``wav[:, :, 0]`` is used.
            beam_size (int): beam width for both pruning steps.
            blank_id (int): index treated as the CTC blank. Should match the
                CTC head's blank index -- TODO confirm.

        Returns:
            List[Tuple[Tuple[int, ...], float]]: up to ``beam_size``
                (prefix, log-likelihood) pairs, best first.
        """
        x = self.enc(self._extract_features(wav[:, :, 0]))
        maxlen = x.shape[1]
        ctc_probs = self.ctc.log_softmax(x)  # (1, maxlen, vocab_size)
        ctc_probs = ctc_probs.squeeze(0)

        # cur_hyps: (prefix, (blank_ending_score, none_blank_ending_score))
        # blank_ending_score and none_blank_ending_score in ln domain
        cur_hyps = [(tuple(), (0.0, -float('inf')))]
        # 2. CTC beam search step by step
        for t in range(0, maxlen):
            logp = ctc_probs[t]  # (vocab_size,)
            # key: prefix, value (pb, pnb), default value(-inf, -inf)
            next_hyps = defaultdict(lambda: (-float('inf'), -float('inf')))
            # 2.1 First beam prune: select topk best
            top_k_logp, top_k_index = logp.topk(beam_size)  # (beam_size,)
            # Move both to the host once per frame rather than calling
            # .item() for every candidate; top_k_logp[i] == logp[index[i]].
            for ps, s in zip(top_k_logp.tolist(), top_k_index.tolist()):
                for prefix, (pb, pnb) in cur_hyps:
                    last = prefix[-1] if len(prefix) > 0 else None
                    if s == blank_id:  # blank
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pb = log_add([n_pb, pb + ps, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                    elif s == last:
                        # Update *ss -> *s;
                        n_pb, n_pnb = next_hyps[prefix]
                        n_pnb = log_add([n_pnb, pnb + ps])
                        next_hyps[prefix] = (n_pb, n_pnb)
                        # Update *s-s -> *ss, - is for blank
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)
                    else:
                        n_prefix = prefix + (s, )
                        n_pb, n_pnb = next_hyps[n_prefix]
                        n_pnb = log_add([n_pnb, pb + ps, pnb + ps])
                        next_hyps[n_prefix] = (n_pb, n_pnb)

            # 2.2 Second beam prune
            next_hyps = sorted(
                next_hyps.items(),
                key=lambda x: log_add(list(x[1])),
                reverse=True)
            cur_hyps = next_hyps[:beam_size]

        return [(prefix, log_add([pb, pnb])) for prefix, (pb, pnb) in cur_hyps]

    def ctc_prefix_beam_search(self, wav, beam_size) -> List[int]:
        """Apply CTC prefix beam search and return the best hypothesis.

        Args:
            wav (paddle.Tensor): rank-3 audio tensor with batch size 1; only
                ``wav[:, :, 0]`` is used.
            beam_size (int): beam width.

        Returns:
            List[int]: token ids of the highest-scoring prefix.
        """
        hyps = self._ctc_prefix_beam_search(wav, beam_size)
        return list(hyps[0][0])


class Wav2vec2Base(nn.Layer):
    """Wav2vec2 model"""

    def __init__(self, config: dict):
        super().__init__()
        wav2vec2_config = Wav2Vec2ConfigPure(config)
        wav2vec2 = Wav2Vec2Model(wav2vec2_config)
        self.wav2vec2 = wav2vec2

    @classmethod
    def from_config(cls, configs: dict):
        """init model.

        Args:
            configs (dict): config dict, passed straight to the constructor.

        Returns:
            nn.Layer: Wav2Vec2Base
        """
        model = cls(configs)
        return model

    def forward(self, wav):
        """Return the raw output of the wrapped Wav2Vec2Model."""
        out = self.wav2vec2(wav)
        return out