In [1]:
%cd /workspace/wenet/
%pwd

/workspace/wenet


'/workspace/wenet'

In [2]:

import argparse
import copy
import logging
import os

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from wenet.dataset.dataset import AudioDataset, CollateFunc
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint, save_checkpoint
from wenet.utils.executor import Executor
from wenet.utils.scheduler import WarmupLR

os.environ['CUDA_VISIBLE_DEVICES'] = "0"

In [3]:
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--config', default="examples/aishell/s0/conf/train_conformer.yaml",  help='config file')
parser.add_argument('--train_data', default="examples/aishell/s0/raw_wav/train/format.data",  help='train data file')
parser.add_argument('--cv_data', default="examples/aishell/s0/raw_wav/dev/format.data",  help='cv data file')
parser.add_argument('--gpu',
                    type=int,
                    default=-1,
                    help='gpu id for this local rank, -1 for cpu')
parser.add_argument('--model_dir' , help='save model dir')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--tensorboard_dir',
                    default='tensorboard',
                    help='tensorboard log dir')
parser.add_argument('--ddp.rank',
                    dest='rank',
                    default=0,
                    type=int,
                    help='global rank for distributed training')
parser.add_argument('--ddp.world_size',
                    dest='world_size',
                    default=-1,
                    type=int,
                    help='''number of total processes/gpus for
                    distributed training''')
parser.add_argument('--ddp.dist_backend',
                    dest='dist_backend',
                    default='nccl',
                    choices=['nccl', 'gloo'],
                    help='distributed backend')
parser.add_argument('--ddp.init_method',
                    dest='init_method',
                    default=None,
                    help='ddp init method')
parser.add_argument('--num_workers',
                    default=0,
                    type=int,
                    help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
                    action='store_true',
                    default=False,
                    help='Use pinned memory buffers used for reading')
parser.add_argument('--cmvn', default="examples/aishell/s0/raw_wav/train/global_cmvn", help='global cmvn file')

args = parser.parse_args([])
print(vars(args))

{'config': 'examples/aishell/s0/conf/train_conformer.yaml', 'train_data': 'examples/aishell/s0/raw_wav/train/format.data', 'cv_data': 'examples/aishell/s0/raw_wav/dev/format.data', 'gpu': -1, 'model_dir': None, 'checkpoint': None, 'tensorboard_dir': 'tensorboard', 'rank': 0, 'world_size': -1, 'dist_backend': 'nccl', 'init_method': None, 'num_workers': 0, 'pin_memory': False, 'cmvn': 'examples/aishell/s0/raw_wav/train/global_cmvn'}


In [4]:
# Set random seed
torch.manual_seed(777)
print(args)
with open(args.config, 'r') as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

Namespace(checkpoint=None, cmvn='examples/aishell/s0/raw_wav/train/global_cmvn', config='examples/aishell/s0/conf/train_conformer.yaml', cv_data='examples/aishell/s0/raw_wav/dev/format.data', dist_backend='nccl', gpu=-1, init_method=None, model_dir=None, num_workers=0, pin_memory=False, rank=0, tensorboard_dir='tensorboard', train_data='examples/aishell/s0/raw_wav/train/format.data', world_size=-1)


In [5]:
raw_wav = configs['raw_wav']

train_collate_func = CollateFunc(**configs['collate_conf'],
                                 raw_wav=raw_wav)

cv_collate_conf = copy.deepcopy(configs['collate_conf'])
# no augmenation on cv set
cv_collate_conf['spec_aug'] = False
cv_collate_conf['spec_sub'] = False
if raw_wav:
    cv_collate_conf['feature_dither'] = 0.0
    cv_collate_conf['speed_perturb'] = False
    cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

dataset_conf = configs.get('dataset_conf', {})
train_dataset = AudioDataset(args.train_data,
                             **dataset_conf,
                             raw_wav=raw_wav)
cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)
# 120098 data/train/wav.scp
print(len(train_dataset), 'batches')
# 14326 data/dev/wav.scp
print(len(cv_dataset))

7507 batches
896


In [6]:
train_sampler = None
cv_sampler = None
train_data_loader = DataLoader(train_dataset,
                               collate_fn=train_collate_func,
                               sampler=train_sampler,
                               #shuffle=(train_sampler is None),
                               shuffle=False,
                               pin_memory=args.pin_memory,
                               batch_size=1,
                               num_workers=args.num_workers)
cv_data_loader = DataLoader(cv_dataset,
                            collate_fn=cv_collate_func,
                            sampler=cv_sampler,
                            shuffle=False,
                            batch_size=1,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers)
print(len(cv_data_loader))

896


In [7]:
if raw_wav:
    input_dim = configs['collate_conf']['feature_extraction_conf'][
        'mel_bins']
else:
    input_dim = train_dataset.input_dim
vocab_size = train_dataset.output_dim
print(vocab_size, 'vocab')
print(input_dim , 'feat dim')

4233 vocab
80 feat dim


In [8]:
# Save configs to model_dir/train.yaml for inference and export
configs['input_dim'] = input_dim
configs['output_dim'] = vocab_size
configs['cmvn_file'] = args.cmvn
configs['is_json_cmvn'] = raw_wav
print(args.cmvn)


examples/aishell/s0/raw_wav/train/global_cmvn


In [9]:
import json
import math
import numpy as np
def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, vars]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file:  kaldi text style global cmvn file, which
           is generated by:
           compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, vars]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_npz_cmvn(npz_cmvn_file, eps=1e-20):
    npzfile = np.load(npz_cmvn_file)
    means = npzfile["mean"]  #(1, D)
    std = npzfile["std"]  #(1, D)
    std = np.clip(std, eps, None)
    variance = 1.0 / std
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, optional[npz, json, kaldi].

    Raises:
        ValueError: file type not support.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    assert filetype in ['npz', 'json', 'kaldi'], filetype
    filetype = filetype.lower()
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        cmvn = _load_npz_cmvn(cmvn_file)
    else:
        raise ValueError(f"cmvn file type no support: {filetype}")
    return cmvn[0], cmvn[1]

mean, istd = load_cmvn(args.cmvn, 'json')
print(mean.shape)
print(istd.shape)
print(mean)
print(istd)

(80,)
(80,)
[ 9.87176362  9.93891555 10.23818678 10.85971412 11.68652649 12.2548801
 12.65768161 12.86138996 12.80733912 12.56625574 12.32007066 12.13879205
 12.31318868 12.55255216 12.61223855 12.56974526 12.38972728 12.14383338
 12.09285066 11.79395822 11.62259065 11.9263303  11.8154422  11.95122567
 11.83180553 11.88788759 11.79014437 11.88072035 11.90005711 11.97348142
 12.00982189 12.00881339 12.02619706 12.10479646 12.21555081 12.34399304
 12.45014401 12.4966879  12.48653775 12.3550783  12.39291732 12.2553737
 12.26496277 12.25314244 12.32545763 12.43359839 12.54867439 12.6763342
 12.80920698 12.92934681 12.96115138 12.96883353 12.99593057 13.04728142
 13.0588804  13.05737948 12.99921175 12.93402238 12.87429219 12.71652995
 12.48942004 12.27478385 12.26163069 12.28631891 12.31956049 12.4229073
 12.51480191 12.5785164  12.64719411 12.73762568 12.80017069 12.86872766
 12.96666856 13.06478583 13.15915908 13.27284306 13.31081821 13.23904279
 12.87936075 11.18310185]
[0.61219383 0.497

In [10]:
# Init asr model from configs
model = init_asr_model(configs)
print(model)

ASRModel(
  (encoder): ConformerEncoder(
    (global_cmvn): GlobalCMVN()
    (embed): Conv2dSubsampling4(
      (conv): Sequential(
        (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(2, 2))
        (1): ReLU()
        (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2))
        (3): ReLU()
      )
      (out): Sequential(
        (0): Linear(in_features=4864, out_features=256, bias=True)
      )
      (pos_enc): RelPositionalEncoding(
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (after_norm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
    (encoders): ModuleList(
      (0): ConformerEncoderLayer(
        (self_attn): RelPositionMultiHeadedAttention(
          (linear_q): Linear(in_features=256, out_features=256, bias=True)
          (linear_k): Linear(in_features=256, out_features=256, bias=True)
          (linear_v): Linear(in_features=256, out_features=256, bias=True)
          (linear_out): Linear(in_features=256, out_features=256, bias=True)

In [11]:

def summary(layer, print_func=print):
    num_params = num_elements = 0
    for name, param in layer.state_dict().items():
        if print_func:
            print_func(
                "{} | {} | {}".format(name, param.shape, np.prod(param.shape)))
        num_elements += np.prod(param.shape)
        num_params += 1
    if print_func:
        print_func(
            f"Total parameters: {num_params}, {num_elements} elements."
        )
        
def print_params(model, print_func=print):
    if print_func is None:
        return
    total = 0.0
    num_params = 0.0
    for n, p in model.named_parameters():
        msg = f"{n} | {p.shape} | {np.prod(p.shape)} | {p.requires_grad}"
        total += np.prod(p.shape)
        num_params += 1
        if print_func:
            print_func(msg)
    if print_func:
        print_func(f"Total parameters: {num_params}, {total} elements.")

In [12]:
summary(model)

encoder.global_cmvn.mean | torch.Size([80]) | 80
encoder.global_cmvn.istd | torch.Size([80]) | 80
encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304
encoder.embed.conv.0.bias | torch.Size([256]) | 256
encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824
encoder.embed.conv.2.bias | torch.Size([256]) | 256
encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184
encoder.embed.out.0.bias | torch.Size([256]) | 256
encoder.after_norm.weight | torch.Size([256]) | 256
encoder.after_norm.bias | torch.Size([256]) | 256
encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256
encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256
encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536
encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256
encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536
encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256
encoder.encoders.0.s

In [13]:
print_params(model)

encoder.embed.conv.0.weight | torch.Size([256, 1, 3, 3]) | 2304 | True
encoder.embed.conv.0.bias | torch.Size([256]) | 256 | True
encoder.embed.conv.2.weight | torch.Size([256, 256, 3, 3]) | 589824 | True
encoder.embed.conv.2.bias | torch.Size([256]) | 256 | True
encoder.embed.out.0.weight | torch.Size([256, 4864]) | 1245184 | True
encoder.embed.out.0.bias | torch.Size([256]) | 256 | True
encoder.after_norm.weight | torch.Size([256]) | 256 | True
encoder.after_norm.bias | torch.Size([256]) | 256 | True
encoder.encoders.0.self_attn.pos_bias_u | torch.Size([4, 64]) | 256 | True
encoder.encoders.0.self_attn.pos_bias_v | torch.Size([4, 64]) | 256 | True
encoder.encoders.0.self_attn.linear_q.weight | torch.Size([256, 256]) | 65536 | True
encoder.encoders.0.self_attn.linear_q.bias | torch.Size([256]) | 256 | True
encoder.encoders.0.self_attn.linear_k.weight | torch.Size([256, 256]) | 65536 | True
encoder.encoders.0.self_attn.linear_k.bias | torch.Size([256]) | 256 | True
encoder.encoders.0.s

In [14]:
for batch in cv_data_loader:
    keys, feat, text, feat_len, text_len = batch
    print(keys)
    print(feat.shape)
    print(feat)
    print(feat_len)
    print(text)
    print(text_len)
    np.savez('data.npz', keys=keys, feat=feat.numpy(), feat_len=feat_len.numpy(), text=text.numpy(), text_len=text_len.numpy())
    break

['BAC009S0739W0246', 'BAC009S0727W0424', 'BAC009S0753W0412', 'BAC009S0756W0206', 'BAC009S0740W0414', 'BAC009S0728W0426', 'BAC009S0739W0214', 'BAC009S0753W0423', 'BAC009S0734W0201', 'BAC009S0740W0427', 'BAC009S0730W0423', 'BAC009S0728W0367', 'BAC009S0730W0418', 'BAC009S0727W0157', 'BAC009S0749W0409', 'BAC009S0727W0418']
torch.Size([16, 207, 80])
tensor([[[ 8.9946,  9.5383,  9.1916,  ..., 10.5074,  9.5633,  8.2564],
         [ 9.7988, 10.4052,  9.2651,  ..., 10.2512,  9.5440,  8.8738],
         [10.6891, 10.3955,  8.0535,  ...,  9.9067, 10.0649,  8.0509],
         ...,
         [ 9.2180,  9.6507,  8.5053,  ...,  9.6872,  8.7425,  7.9865],
         [10.1291,  9.9352,  9.3798,  ...,  9.5639,  9.8260,  8.9795],
         [ 9.0955,  7.1338,  9.4680,  ...,  9.4727,  9.0212,  7.4479]],

        [[11.4310, 10.6719,  6.0841,  ...,  9.3827,  8.7297,  7.5316],
         [ 9.7317,  7.8105,  7.5715,  ..., 10.0430,  9.2436,  7.3541],
         [10.6502, 10.6006,  8.4678,  ...,  9.2814,  9.1869,  8.0703]

In [15]:
!ls
!cp data.npz /workspace/DeepSpeech-2.x/.notebook

CODE_OF_CONDUCT.md  data.npz  install.sh  README.md	    tools
CONTRIBUTING.md     docs      LICENSE	  requirements.txt  venv
CPPLINT.cfg	    examples  Makefile	  runtime	    wenet


In [16]:
model.cpu().eval()
total_loss, attention_loss, ctc_loss = model(feat, feat_len,
                                         text, text_len)
print(total_loss, attention_loss, ctc_loss )

tensor(111.9988)
tensor(830.9634, grad_fn=<SumBackward0>)
tensor([False, False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False,  True,  True, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False, False, False,
        False, False,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
         True,  True])
tensor(669.4633, grad_fn=<SumBackward0>)
tensor(142.4888, gra

In [17]:
print(total_loss.device)

cpu


In [18]:
model.cuda().eval()
feat=feat.cuda()
feat_len=feat_len.cuda()
text=text.cuda()
text_len=text_len.cuda()

total_loss, attention_loss, ctc_loss = model(feat, feat_len,
                                         text, text_len)
print(total_loss.device)
print(total_loss.cpu().data.numpy(), attention_loss.cpu().data.numpy(), ctc_loss.cpu().data.numpy() )

tensor(112., device='cuda:0')
tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)
tensor([False, False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False,  True,  True, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False, False, False,
        False, False,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
         True,  True], device='cuda:0')
tensor(669.463

In [19]:
encoder_out, encoder_mask = model.encoder(feat, feat_len)
print(encoder_out.shape)
print(encoder_mask.shape)
print(encoder_out[0])

np.savez('/workspace/DeepSpeech-2.x/.notebook/encoder.npz',
         mask=encoder_mask.cpu().detach().numpy(),  
         out=encoder_out.cpu().detach().numpy())

torch.Size([16, 51, 256])
torch.Size([16, 1, 51])
tensor([[-0.7019,  0.5625,  0.6880,  ...,  1.1237,  0.7804,  1.1369],
        [-0.7788,  0.3913,  0.7189,  ...,  1.2519,  0.8862,  1.3173],
        [-0.9591,  0.6346,  0.8767,  ...,  0.9818,  0.7440,  1.2903],
        ...,
        [-1.0732,  0.6724,  0.9230,  ...,  0.9075,  0.8177,  1.3240],
        [-1.1654,  0.6820,  0.6939,  ...,  1.2238,  0.8028,  1.4507],
        [-1.2732,  0.7146,  0.7582,  ...,  0.9415,  0.8775,  1.2623]],
       device='cuda:0', grad_fn=<SelectBackward>)


In [20]:
# dump torch model to paddle
import numpy as np
state_dict = model.state_dict()
paddle_state_dict = {}

for n, p in state_dict.items():
    name_change=True

    if 'norm.running_mean' in n:
        new_n = n.replace('norm.running_', 'norm._')
    elif 'norm.running_var' in n:
        new_n = n.replace('norm.running_var', 'norm._variance')
    else:
        name_change=False
        new_n = n
    
    if name_change:
        print(f"{n} -> {new_n}")
        
    p = p.cpu().detach().numpy()
    if n.endswith('weight') and p.ndim == 2 and 'embed.0.weight' not in n:
        new_p = p.T
        print(f"{n}: {p.shape} -> {new_p.shape}")
    else:
        new_p = p
        
    if 'global_cmvn.mean' in n:
        print(p, p.dtype)
        
    paddle_state_dict[new_n] = new_p
    
np.savez('/workspace/DeepSpeech-2.x/.notebook/model',
       state=paddle_state_dict)

[ 9.871763   9.938915  10.238187  10.8597145 11.686526  12.25488
 12.657681  12.86139   12.807339  12.566256  12.32007   12.138792
 12.313189  12.552552  12.612239  12.569745  12.389728  12.143833
 12.092851  11.793959  11.622591  11.926331  11.815442  11.951225
 11.831805  11.887888  11.790144  11.88072   11.900057  11.973481
 12.009822  12.008814  12.026197  12.104796  12.21555   12.343993
 12.450144  12.496688  12.486538  12.355079  12.392918  12.255374
 12.264963  12.253142  12.325458  12.4335985 12.548675  12.676334
 12.809207  12.929347  12.961151  12.968834  12.995931  13.047281
 13.058881  13.05738   12.999211  12.934022  12.874292  12.71653
 12.48942   12.274784  12.261631  12.286319  12.31956   12.422907
 12.514802  12.578516  12.647194  12.737626  12.800171  12.868728
 12.966668  13.064786  13.159159  13.272843  13.310819  13.239043
 12.879361  11.183102 ] float32
encoder.embed.out.0.weight: (256, 4864) -> (4864, 256)
encoder.encoders.0.self_attn.linear_q.weight: (256, 256) 

In [21]:
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
loss_ctc = model.ctc(encoder_out, encoder_out_lens, text, text_len)
print(loss_ctc)
dir(loss_ctc)
loss_ctc.backward()
print(loss_ctc.grad)
#print(model.ctc.ctc_lo.weight.grad)
print(model.ctc.ctc_lo.weight.grad.T.cpu().data.numpy())
print(model.ctc.ctc_lo.bias.grad.cpu().data.numpy())

tensor(377.3326, device='cuda:0', grad_fn=<DivBackward0>)
None
[[ 3.16902351e+00 -1.51765049e-02  4.91097234e-02 ... -2.47973716e-03
  -5.93366381e-03 -7.26613170e-03]
 [-1.74185038e+00  7.75875803e-03 -4.49435972e-02 ...  9.92415240e-04
   2.46338220e-03  2.31891591e-03]
 [-2.33343077e+00  1.30476682e-02 -2.66557615e-02 ...  2.27533933e-03
   5.76929189e-03  7.48792710e-03]
 ...
 [-4.30356789e+00  2.46056803e-02 -9.00955945e-02 ...  4.43160534e-03
   1.16123557e-02  1.44716976e-02]
 [-3.36919212e+00  1.73155665e-02 -6.36875406e-02 ...  3.28367390e-03
   8.58021621e-03  1.07796099e-02]
 [-6.62039661e+00  3.49958315e-02 -1.23963736e-01 ...  6.36674836e-03
   1.60815325e-02  2.03892551e-02]]
[-4.3777566e+00  2.3245990e-02 -9.3339972e-02 ...  4.2569702e-03
  1.0920014e-02  1.3787906e-02]


  print(loss_ctc.grad)


In [24]:
loss_att, acc_att = model._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_len)
print(loss_att, acc_att)

tensor(112., device='cuda:0')
tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)
tensor([False, False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False,  True,  True, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False, False, False,
        False, False,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
         True,  True], device='cuda:0')
tensor(669.463

In [25]:
def pad_list(xs, pad_value: int):
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad

def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int):

    _sos = torch.tensor([sos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)

In [26]:
ys_pad = text
ys_pad_lens = text_len
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, model.sos, model.eos,
                                            model.ignore_id)
ys_in_lens = ys_pad_lens + 1
print(ys_in_pad)
print(ys_out_pad)

tensor([[4232, 2995, 3116, 1209,  565, 4232, 4232],
        [4232,  236, 1176,  331,   66, 3925, 4077],
        [4232, 2693,  524,  234, 1145,  366, 4232],
        [4232, 3875, 4211, 3062,  700, 4232, 4232],
        [4232,  272,  987, 1134,  494, 2959, 4232],
        [4232, 1936, 3715,  120, 2553, 2695, 2710],
        [4232,   25, 1149, 3930, 4232, 4232, 4232],
        [4232, 1753, 1778, 1237,  482, 3925,  110],
        [4232, 3703,    2,  565, 3827, 4232, 4232],
        [4232, 1150, 2734,   10, 2478, 3490, 4232],
        [4232,  426,  811,   95,  489,  144, 4232],
        [4232, 2313, 2006,  489,  975, 4232, 4232],
        [4232, 3702, 3414,  205, 1488, 2966, 1347],
        [4232,   70, 1741,  702, 1666, 4232, 4232],
        [4232,  703, 1778, 1030,  849, 4232, 4232],
        [4232,  814, 1674,  115, 3827, 4232, 4232]], device='cuda:0')
tensor([[2995, 3116, 1209,  565, 4232,   -1,   -1],
        [ 236, 1176,  331,   66, 3925, 4077, 4232],
        [2693,  524,  234, 1145,  366, 4232,  

In [27]:
decoder_out, _ = model.decoder(encoder_out, encoder_mask, ys_in_pad,
                                      ys_in_lens)
print(decoder_out.shape)
print(decoder_out[0])

torch.Size([16, 7, 4233])
tensor([[-3.7639e-01, -8.2272e-01,  7.4276e-01,  ...,  3.4201e-01,
          1.5035e-02,  4.0337e-01],
        [-8.7386e-01, -3.1389e-01,  4.1988e-01,  ...,  3.7724e-01,
         -1.4353e-01, -1.0024e+00],
        [-4.3505e-01,  3.4505e-02, -2.8710e-01,  ...,  7.7274e-02,
         -1.1672e+00, -2.6849e-01],
        ...,
        [ 4.2471e-01,  5.8886e-01,  2.0204e-02,  ...,  3.7405e-01,
          4.5470e-02, -3.7139e-01],
        [-3.7978e-01, -8.1084e-01,  7.5725e-01,  ...,  2.6039e-01,
         -7.9347e-04,  4.2538e-01],
        [-3.8280e-01, -8.1207e-01,  7.4943e-01,  ...,  2.6173e-01,
         -1.0499e-03,  4.2679e-01]], device='cuda:0', grad_fn=<SelectBackward>)


In [28]:
print(decoder_out.dtype)
print(ys_out_pad.dtype)
loss_att = model.criterion_att(decoder_out, ys_out_pad)
print(loss_att)
print(ys_out_pad)
print(decoder_out[0])
np.savez('/workspace/DeepSpeech-2.x/.notebook/decoder',
       decoder_out=decoder_out.cpu().detach().numpy())

torch.float32
torch.int64
tensor(112., device='cuda:0')
tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)
tensor([False, False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False,  True,  True, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True,  True, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False, False, False, False,  True,
        False, False, False, False, False, False,  True, False, False, False,
        False, False,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False,  True,  True, False, False,
        False, False, False,  True,  True, False, False, False, False, False,
         True,  True], devic

In [29]:
import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model outputs and data labels tensors are flatten to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be calculated for loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_id (batch, seqlen)
        Returns:
            loss (torch.Tensor) : The KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        print(self.smoothing / (self.size - 1))
        print(true_dist)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        print(true_dist.dtype)
        print(true_dist.square().sum())
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        print(kl.sum())
        denom = total if self.normalize_length else batch_size
        print(ignore)
        numer= kl.masked_fill(ignore.unsqueeze(1), 0).sum()
        print(numer)
        return numer /denom

In [30]:
criteron = LabelSmoothingLoss(4233, -1, 0.1, False)
loss_att = criteron(decoder_out, ys_out_pad)
print(loss_att)
print(ys_out_pad.dtype)

2.3629489603024576e-05
tensor([[2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05],
        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05],
        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05],
        ...,
        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05],
        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05],
        [2.3629e-05, 2.3629e-05, 2.3629e-05,  ..., 2.3629e-05, 2.3629e-05,
         2.3629e-05]], device='cuda:0')
torch.float32
tensor(90.7203, device='cuda:0')
tensor(830.9634, device='cuda:0', grad_fn=<SumBackward0>)
tensor([False, False, False, False, False,  True,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
         True, False, False, False, False, False,  True,  True, False, False,
        False, False, Fal

In [31]:
loss_att.backward()
print(loss_att.grad)
print(decoder_out.grad)

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

In [32]:
print(model.decoder.output_layer.bias.grad)
print(model.decoder.output_layer.weight.grad)

tensor([ 0.0024,  0.0019, -0.1098,  ...,  0.0028,  0.0020, -1.7978],
       device='cuda:0')
tensor([[ 6.5052e-04,  6.4419e-05, -6.1955e-06,  ...,  9.8220e-04,
         -2.5918e-05,  3.3754e-04],
        [ 3.9305e-04,  4.5799e-04,  1.4362e-04,  ...,  4.6800e-04,
          1.6911e-04,  2.7067e-04],
        [-1.3593e-01,  5.2201e-02,  3.2895e-02,  ...,  2.4580e-02,
          1.4590e-01, -4.6850e-02],
        ...,
        [ 1.0434e-03,  4.2251e-04,  6.5688e-04,  ...,  1.2144e-03,
          2.1159e-04,  6.6838e-04],
        [ 6.4997e-04,  4.4301e-04,  4.1550e-04,  ...,  1.0420e-03,
          2.4114e-04,  1.5338e-04],
        [-9.9337e-01,  5.4573e-01, -1.1371e-02,  ..., -4.3175e-01,
         -2.7850e-01, -4.4679e-01]], device='cuda:0')


In [42]:
xs = model.encoder.global_cmvn(feat)
print(xs)

tensor([[[-5.3698e-01, -1.9911e-01, -3.4997e-01,  ..., -8.2428e-01,
          -1.0265e+00, -9.6301e-01],
         [-4.4642e-02,  2.3176e-01, -3.2539e-01,  ..., -9.0159e-01,
          -1.0325e+00, -7.5987e-01],
         [ 5.0035e-01,  2.2691e-01, -7.3052e-01,  ..., -1.0055e+00,
          -8.7123e-01, -1.0306e+00],
         ...,
         [-4.0024e-01, -1.4325e-01, -5.7947e-01,  ..., -1.0718e+00,
          -1.2806e+00, -1.0518e+00],
         [ 1.5755e-01, -1.8495e-03, -2.8703e-01,  ..., -1.1090e+00,
          -9.4519e-01, -7.2506e-01],
         [-4.7520e-01, -1.3942e+00, -2.5754e-01,  ..., -1.1365e+00,
          -1.1943e+00, -1.2290e+00]],

        [[ 9.5454e-01,  3.6428e-01, -1.3891e+00,  ..., -1.1637e+00,
          -1.2845e+00, -1.2015e+00],
         [-8.5735e-02, -1.0579e+00, -8.9173e-01,  ..., -9.6441e-01,
          -1.1255e+00, -1.2599e+00],
         [ 4.7654e-01,  3.2887e-01, -5.9201e-01,  ..., -1.1942e+00,
          -1.1430e+00, -1.0242e+00],
         ...,
         [-4.7431e-01, -3

In [43]:
from wenet.utils.mask import make_pad_mask
masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)
print(masks)

tensor([[[ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True,  True,  True]],

        [[ True,  True,  True,  ...,  True, False, False]],

        ...,

        [[ True,  True,  True,  ..., False, False, False]],

        [[ True,  True,  True,  ..., False, False, False]],

        [[ True,  True,  True,  ..., False, False, False]]], device='cuda:0')


In [44]:
xs, pos_emb, masks = model.encoder.embed(xs, masks)

In [45]:
print(xs)
print(pos_emb)
print(masks)
print(masks.shape)

tensor([[[-0.5482,  2.2866, -1.0750,  ...,  1.4504,  0.2895, -0.6945],
         [-0.8013,  1.7688, -1.6639,  ...,  1.8332,  0.6791, -0.2000],
         [-1.7112,  2.7057, -1.3363,  ...,  1.2336,  0.1870, -0.5735],
         ...,
         [-0.9697,  2.3129, -0.8752,  ...,  0.8584,  0.4853, -0.4177],
         [-1.3609,  2.1779, -1.7813,  ...,  2.0928,  0.2528, -0.3650],
         [-1.6967,  2.3544, -1.7417,  ...,  1.3670,  0.5951, -0.7415]],

        [[-1.9828,  2.3178, -0.9079,  ...,  0.4117,  0.5006,  0.0872],
         [-0.7640,  1.3558, -1.3613,  ...,  0.7317,  0.6784,  0.1685],
         [-0.9504,  1.6038, -1.3030,  ...,  0.5754,  0.2677,  0.3343],
         ...,
         [-1.4757,  2.5317, -1.2321,  ...,  1.2997,  0.5019, -0.1034],
         [-1.1731,  2.3172, -1.2542,  ...,  1.7391,  0.2171, -0.4445],
         [-1.2700,  3.2229, -0.8872,  ...,  1.6461,  0.0973, -0.7679]],

        [[-0.5873,  1.4291, -1.3950,  ...,  0.2102,  0.1027,  0.0918],
         [ 0.1743,  1.7834, -1.6422,  ...,  0

In [46]:
xs = model.encoder.global_cmvn(feat)
masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)
xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)
print(xs.cpu().detach().numpy())

[[[-0.54822     2.2866027  -1.0750197  ...  1.4503604   0.28950194
   -0.6945408 ]
  [-0.8012542   1.7687558  -1.6638877  ...  1.833158    0.6791494
   -0.1999542 ]
  [-1.7112465   2.7057455  -1.3363413  ...  1.2336441   0.18697014
   -0.5735198 ]
  ...
  [-0.96968573  2.312949   -0.87524825 ...  0.85838526  0.4853347
   -0.41773027]
  [-1.3609431   2.1778803  -1.7812773  ...  2.0927877   0.25282228
   -0.36496443]
  [-1.6967483   2.3543842  -1.7416853  ...  1.366951    0.59511113
   -0.74147725]]

 [[-1.9828408   2.31777    -0.9078527  ...  0.41170627  0.5006162
    0.08721463]
  [-0.76404583  1.3557773  -1.3612567  ...  0.7317046   0.678426
    0.16851945]
  [-0.95044655  1.6037656  -1.3029968  ...  0.57544005  0.26769355
    0.33433008]
  ...
  [-1.475677    2.531713   -1.2320715  ...  1.2996731   0.50191855
   -0.10343577]
  [-1.1730809   2.3172235  -1.2542105  ...  1.7391105   0.21709818
   -0.44447583]
  [-1.2699623   3.2228963  -0.8871915  ...  1.6460502   0.09731755
   -0.76786

In [47]:
xs = model.encoder.global_cmvn(feat)
masks = ~make_pad_mask(feat_len).unsqueeze(1)  # (B, 1, L)

x = xs.unsqueeze(1)
x = model.encoder.embed.conv(x)
print(x)

tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 2.0908e-03],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 1.1943e-02, 0.0000e+00,  ..., 0.0000

In [48]:
b, c, t, f = x.size()
x = model.encoder.embed.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
print(x.cpu().detach().numpy())
print(x.shape)

[[[-0.03426375  0.14291267 -0.06718873 ...  0.09064753  0.01809387
   -0.0434088 ]
  [-0.05007839  0.11054724 -0.10399298 ...  0.11457238  0.04244684
   -0.01249714]
  [-0.10695291  0.16910909 -0.08352133 ...  0.07710276  0.01168563
   -0.03584499]
  ...
  [-0.06060536  0.14455931 -0.05470302 ...  0.05364908  0.03033342
   -0.02610814]
  [-0.08505894  0.13611752 -0.11132983 ...  0.13079923  0.01580139
   -0.02281028]
  [-0.10604677  0.14714901 -0.10885533 ...  0.08543444  0.03719445
   -0.04634233]]

 [[-0.12392755  0.14486063 -0.05674079 ...  0.02573164  0.03128851
    0.00545091]
  [-0.04775286  0.08473608 -0.08507854 ...  0.04573154  0.04240163
    0.01053247]
  [-0.05940291  0.10023535 -0.0814373  ...  0.035965    0.01673085
    0.02089563]
  ...
  [-0.09222981  0.15823206 -0.07700447 ...  0.08122957  0.03136991
   -0.00646474]
  [-0.07331756  0.14482647 -0.07838815 ...  0.1086944   0.01356864
   -0.02777974]
  [-0.07937264  0.20143102 -0.05544947 ...  0.10287814  0.00608235
   -0.

In [49]:
x, pos_emb = model.encoder.embed.pos_enc(x, 0)
print(x.cpu().detach().numpy())

[[[-0.54822     2.2866027  -1.0750197  ...  1.4503604   0.28950194
   -0.6945408 ]
  [-0.8012542   1.7687558  -1.6638877  ...  1.833158    0.6791494
   -0.1999542 ]
  [-1.7112465   2.7057455  -1.3363413  ...  1.2336441   0.18697014
   -0.5735198 ]
  ...
  [-0.96968573  2.312949   -0.87524825 ...  0.85838526  0.4853347
   -0.41773027]
  [-1.3609431   2.1778803  -1.7812773  ...  2.0927877   0.25282228
   -0.36496443]
  [-1.6967483   2.3543842  -1.7416853  ...  1.366951    0.59511113
   -0.74147725]]

 [[-1.9828408   2.31777    -0.9078527  ...  0.41170627  0.5006162
    0.08721463]
  [-0.76404583  1.3557773  -1.3612567  ...  0.7317046   0.678426
    0.16851945]
  [-0.95044655  1.6037656  -1.3029968  ...  0.57544005  0.26769355
    0.33433008]
  ...
  [-1.475677    2.531713   -1.2320715  ...  1.2996731   0.50191855
   -0.10343577]
  [-1.1730809   2.3172235  -1.2542105  ...  1.7391105   0.21709818
   -0.44447583]
  [-1.2699623   3.2228963  -0.8871915  ...  1.6460502   0.09731755
   -0.76786

In [50]:
print(pos_emb.dtype)
print(pos_emb.cpu().detach().numpy())

torch.float32
[[[ 0.0000000e+00  1.0000000e+00  0.0000000e+00 ...  1.0000000e+00
    0.0000000e+00  1.0000000e+00]
  [ 8.4147096e-01  5.4030234e-01  8.0196178e-01 ...  1.0000000e+00
    1.0746076e-04  1.0000000e+00]
  [ 9.0929741e-01 -4.1614684e-01  9.5814437e-01 ...  1.0000000e+00
    2.1492151e-04  1.0000000e+00]
  ...
  [-7.6825464e-01 -6.4014435e-01  6.3279724e-01 ...  9.9998462e-01
    5.1580933e-03  9.9998671e-01]
  [-9.5375264e-01  3.0059254e-01  9.9899054e-01 ...  9.9998397e-01
    5.2655530e-03  9.9998611e-01]
  [-2.6237485e-01  9.6496606e-01  5.6074661e-01 ...  9.9998331e-01
    5.3730118e-03  9.9998558e-01]]]


In [54]:
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.
    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

from wenet.utils.mask import make_pad_mask


masks = ~make_pad_mask(feat_len).unsqueeze(1)
xs = model.encoder.global_cmvn(feat)
xs, pos_emb, masks = model.encoder.embed(xs, masks, offset=0)

mask_pad = masks
decoding_chunk_size=0
num_decoding_left_chunks=-1
use_dynamic_left_chunk=-1
use_dynamic_chunk=False
static_chunk_size=-1
chunk_masks = add_optional_chunk_mask(
            xs, 
            masks, 
            use_dynamic_chunk,
            use_dynamic_left_chunk,
            decoding_chunk_size, 
            static_chunk_size,
            num_decoding_left_chunks)

np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_embed', 
         embed_out=xs.cpu().detach().numpy(), 
         pos_emb=pos_emb.cpu().detach().numpy(),
        chunk_masks=chunk_masks.cpu().detach().numpy(),
        mask_pad=mask_pad.cpu().detach().numpy())

model.eval()
# print(chunk_masks)
print(xs.shape)
for layer in model.encoder.encoders:
    #xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    #np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0', enc_0=xs.cpu().detach().numpy())
    
    x = xs
    residual = x
    x_norm = layer.norm_ff_macaron(x)
    !rm /workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff.npz
    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_norm_ff', 
             norm_ff=x_norm.cpu().detach().numpy(),
            xs=xs.cpu().detach().numpy())
    #print(x.cpu().detach().numpy())
    for p in layer.norm_ff_macaron.parameters():
        #print(p, p.sum())
        pass
    
    x = residual + layer.ff_scale * layer.feed_forward_macaron(x_norm)
    
    ps = []
    for n, p in layer.feed_forward_macaron.state_dict().items():
      #print(n, p.cpu().data.numpy())
      ps.append(p.cpu().data.numpy())
      pass

    ff_l_x = layer.feed_forward_macaron.w_1(x_norm)
    ff_l_a_x = layer.feed_forward_macaron.activation(ff_l_x)
    ff_l_a_l_x = layer.feed_forward_macaron.w_2(ff_l_a_x)
    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_ff_out', 
             norm_ff=x_norm.cpu().detach().numpy(),
             ff_out=x.cpu().detach().numpy(),
             ff_l_x = ff_l_x.cpu().detach().numpy(),
             ff_l_a_x=ff_l_a_x.cpu().detach().numpy(),
             ff_l_a_l_x=ff_l_a_l_x.cpu().detach().numpy(),
             ps=ps,
            )
    
    
    residual = x
    x = layer.norm_mha(x)
    x_q = x
    
    x_att = layer.self_attn(x_q, x, x, pos_emb, masks)
    np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_0_selattn_out', 
             x_q=x_q.cpu().detach().numpy(),
             x=x.cpu().detach().numpy(),
             pos_emb = pos_emb.cpu().detach().numpy(),
             mask=mask.cpu().detach().numpy(),
             x_att=x_att.cpu().detach().numpy(),
            )
    
    break
#print(xs.cpu().detach().numpy())


i = 0
for layer in model.encoder.encoders:
    xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    i += 1
    if i == 2:
        np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_2', enc_2=xs.cpu().detach().numpy())
        
np.savez('/workspace/DeepSpeech-2.x/.notebook/enc_all', enc_all=xs.cpu().detach().numpy())

torch.Size([16, 51, 256])


NameError: name 'mask' is not defined

In [None]:
out, mask = model.encoder(feat, feat_len)
#print(out.cpu().detach().numpy())