diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md
index e73f81fa9..de9e488c8 100644
--- a/examples/csmsc/tts2/README.md
+++ b/examples/csmsc/tts2/README.md
@@ -19,7 +19,7 @@ Run the command below to
 4. synthesize wavs.
     - synthesize waveform from `metadata.jsonl`.
     - synthesize waveform from text file.
-6. inference using static model.
+5. inference using static model.
 ```bash
 ./run.sh
 ```
diff --git a/examples/csmsc/tts3/README.md b/examples/csmsc/tts3/README.md
index 42f33faac..7eeb14fc5 100644
--- a/examples/csmsc/tts3/README.md
+++ b/examples/csmsc/tts3/README.md
@@ -19,6 +19,7 @@ Run the command below to
 4. synthesize wavs.
     - synthesize waveform from `metadata.jsonl`.
     - synthesize waveform from text file.
+5. inference using static model.
 ```bash
 ./run.sh
 ```
@@ -189,6 +190,13 @@ optional arguments:
 5. `--output-dir` is the directory to save synthesized audio files.
 6. `--device is` the type of device to run synthesis, 'cpu' and 'gpu' are supported. 'gpu' is recommended for faster synthesis.
+### Inference
+After synthesizing, we will get static models of fastspeech2 and pwgan in `${train_output_path}/inference`.
+`./local/inference.sh` calls `${BIN_DIR}/inference.py`, which provides a Paddle static-model inference example for fastspeech2 + pwgan synthesis.
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
+```
+
 ## Pretrained Model
 Pretrained FastSpeech2 model with no silence in the edge of audios. [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/fastspeech2_nosil_baker_ckpt_0.4.zip)
@@ -215,6 +223,7 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
   --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
   --text=${BIN_DIR}/../sentences.txt \
   --output-dir=exp/default/test_e2e \
+  --inference-dir=exp/default/inference \
   --device="gpu" \
   --phones-dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
 ```
diff --git a/examples/csmsc/tts3/local/inference.sh b/examples/csmsc/tts3/local/inference.sh
new file mode 100755
index 000000000..cab72547c
--- /dev/null
+++ b/examples/csmsc/tts3/local/inference.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+train_output_path=$1
+
+python3 ${BIN_DIR}/inference.py \
+    --inference-dir=${train_output_path}/inference \
+    --text=${BIN_DIR}/../sentences.txt \
+    --output-dir=${train_output_path}/pd_infer_out \
+    --phones-dict=dump/phone_id_map.txt
diff --git a/examples/csmsc/tts3/local/synthesize_e2e.sh b/examples/csmsc/tts3/local/synthesize_e2e.sh
index 8c9755dd0..b65427431 100755
--- a/examples/csmsc/tts3/local/synthesize_e2e.sh
+++ b/examples/csmsc/tts3/local/synthesize_e2e.sh
@@ -15,5 +15,6 @@ python3 ${BIN_DIR}/synthesize_e2e.py \
   --pwg-stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
   --text=${BIN_DIR}/../sentences.txt \
   --output-dir=${train_output_path}/test_e2e \
+  --inference-dir=${train_output_path}/inference \
   --device="gpu" \
   --phones-dict=dump/phone_id_map.txt
diff --git a/examples/csmsc/tts3/run.sh b/examples/csmsc/tts3/run.sh
index f45ddab06..718d60760 100755
--- a/examples/csmsc/tts3/run.sh
+++ b/examples/csmsc/tts3/run.sh
@@ -35,3 +35,8 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # synthesize_e2e, vocoder is pwgan
     CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
 fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # inference with static model
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
+fi
diff --git a/parakeet/exps/fastspeech2/inference.py b/parakeet/exps/fastspeech2/inference.py
new file mode 100644
index 000000000..436760887
--- /dev/null
+++ b/parakeet/exps/fastspeech2/inference.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from pathlib import Path
+
+import soundfile as sf
+from paddle import inference
+
+from parakeet.frontend.zh_frontend import Frontend
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Paddle Inference with fastspeech2 & parallel wavegan.")
+    parser.add_argument(
+        "--inference-dir", type=str, help="dir to save inference models")
+    parser.add_argument(
+        "--text",
+        type=str,
+        help="text to synthesize, a 'utt_id sentence' pair per line")
+    parser.add_argument("--output-dir", type=str, help="output dir")
+    parser.add_argument(
+        "--enable-auto-log", action="store_true", help="use auto log")
+    parser.add_argument(
+        "--phones-dict",
+        type=str,
+        default="phones.txt",
+        help="phone vocabulary file.")
+
+    args, _ = parser.parse_known_args()
+
+    frontend = Frontend(phone_vocab_path=args.phones_dict)
+    print("frontend done!")
+
+    fastspeech2_config = inference.Config(
+        str(Path(args.inference_dir) / "fastspeech2.pdmodel"),
+        str(Path(args.inference_dir) / "fastspeech2.pdiparams"))
+    fastspeech2_config.enable_use_gpu(50, 0)
+    # enable_memory_optim() must stay commented out here, otherwise it will OOM
+    # fastspeech2_config.enable_memory_optim()
+    fastspeech2_predictor = inference.create_predictor(fastspeech2_config)
+
+    pwg_config = inference.Config(
+        str(Path(args.inference_dir) / "pwg.pdmodel"),
+        str(Path(args.inference_dir) / "pwg.pdiparams"))
+    pwg_config.enable_use_gpu(100, 0)
+    pwg_config.enable_memory_optim()
+    pwg_predictor = inference.create_predictor(pwg_config)
+
+    if args.enable_auto_log:
+        import auto_log
+        os.makedirs("output", exist_ok=True)
+        pid = os.getpid()
+        logger = auto_log.AutoLogger(
+            model_name="fastspeech2",
+            model_precision='float32',
+            batch_size=1,
+            data_shape="dynamic",
+            save_path="./output/auto_log.log",
+            inference_config=fastspeech2_config,
+            pids=pid,
+            process_name=None,
+            gpu_ids=0,
+            time_keys=['preprocess_time', 'inference_time', 'postprocess_time'],
+            warmup=0)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    sentences = []
+
+    with open(args.text, 'rt') as f:
+        for line in f:
+            utt_id, sentence = line.strip().split()
+            sentences.append((utt_id, sentence))
+
+    for utt_id, sentence in sentences:
+        if args.enable_auto_log:
+            logger.times.start()
+        input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
+        phone_ids = input_ids["phone_ids"]
+        phones = phone_ids[0].numpy()
+
+        if args.enable_auto_log:
+            logger.times.stamp()
+
+        input_names = fastspeech2_predictor.get_input_names()
+        phones_handle = fastspeech2_predictor.get_input_handle(input_names[0])
+
+        phones_handle.reshape(phones.shape)
+        phones_handle.copy_from_cpu(phones)
+
+        fastspeech2_predictor.run()
+        output_names = fastspeech2_predictor.get_output_names()
+        output_handle = fastspeech2_predictor.get_output_handle(output_names[0])
+        output_data = output_handle.copy_to_cpu()
+
+        input_names = pwg_predictor.get_input_names()
+        mel_handle = pwg_predictor.get_input_handle(input_names[0])
+        mel_handle.reshape(output_data.shape)
+        mel_handle.copy_from_cpu(output_data)
+
+        pwg_predictor.run()
+        output_names = pwg_predictor.get_output_names()
+        output_handle = pwg_predictor.get_output_handle(output_names[0])
+        wav = output_handle.copy_to_cpu()
+
+        if args.enable_auto_log:
+            logger.times.stamp()
+
+        sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
+
+        if args.enable_auto_log:
+            logger.times.end(stamp=True)
+        print(f"{utt_id} done!")
+
+    if args.enable_auto_log:
+        logger.report()
+
+
+if __name__ == "__main__":
+    main()
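The core Paddle Inference pattern used above (get a handle, `reshape`, `copy_from_cpu`, `run`, `copy_to_cpu`) can also be exercised standalone. A minimal sketch, assuming an exported fastspeech2 model at a placeholder path and dummy phone ids:

```python
import numpy as np
from paddle import inference

# Placeholder paths: any exported *.pdmodel / *.pdiparams pair works here.
config = inference.Config("exp/default/inference/fastspeech2.pdmodel",
                          "exp/default/inference/fastspeech2.pdiparams")
config.enable_use_gpu(50, 0)  # 50 MB initial GPU memory pool, GPU id 0
predictor = inference.create_predictor(config)

phones = np.array([12, 37, 5, 81], dtype=np.int64)  # dummy phone ids
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.reshape(phones.shape)
input_handle.copy_from_cpu(phones)
predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
mel = output_handle.copy_to_cpu()  # numpy array of mel frames, (T, 80)
```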
diff --git a/parakeet/exps/fastspeech2/synthesize_e2e.py b/parakeet/exps/fastspeech2/synthesize_e2e.py
index dd1b57c8a..9c036e9fc 100644
--- a/parakeet/exps/fastspeech2/synthesize_e2e.py
+++ b/parakeet/exps/fastspeech2/synthesize_e2e.py
@@ -13,12 +13,15 @@
 # limitations under the License.
 import argparse
 import logging
+import os
 from pathlib import Path
 
 import numpy as np
 import paddle
 import soundfile as sf
 import yaml
+from paddle import jit
+from paddle.static import InputSpec
 from yacs.config import CfgNode
 
 from parakeet.frontend.zh_frontend import Frontend
@@ -74,7 +77,21 @@ def evaluate(args, fastspeech2_config, pwg_config):
     pwg_normalizer = ZScore(mu, std)
 
     fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
+    fastspeech2_inference.eval()
+    fastspeech2_inference = jit.to_static(
+        fastspeech2_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+    paddle.jit.save(fastspeech2_inference,
+                    os.path.join(args.inference_dir, "fastspeech2"))
+    fastspeech2_inference = paddle.jit.load(
+        os.path.join(args.inference_dir, "fastspeech2"))
     pwg_inference = PWGInference(pwg_normalizer, vocoder)
+    pwg_inference.eval()
+    pwg_inference = jit.to_static(
+        pwg_inference, input_spec=[
+            InputSpec([-1, 80], dtype=paddle.float32),
+        ])
+    paddle.jit.save(pwg_inference, os.path.join(args.inference_dir, "pwg"))
+    pwg_inference = paddle.jit.load(os.path.join(args.inference_dir, "pwg"))
 
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -135,6 +152,8 @@ def main():
         type=str,
         help="text to synthesize, a 'utt_id sentence' pair per line.")
     parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--inference-dir", type=str, help="dir to save inference models")
     parser.add_argument(
         "--device", type=str, default="gpu", help="device type to use.")
     parser.add_argument("--verbose", type=int, default=1, help="verbose.")
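The export logic added to `synthesize_e2e.py` is the standard dygraph-to-static recipe: put the layer in eval mode, wrap it with `jit.to_static` and an `InputSpec` whose `-1` axes are dynamic, save, then reload the static graph. A self-contained sketch, with a hypothetical `ToyModel` standing in for `PWGInference` and a placeholder save path:

```python
import paddle
from paddle import jit
from paddle.static import InputSpec


class ToyModel(paddle.nn.Layer):
    """Hypothetical stand-in for FastSpeech2Inference / PWGInference."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(80, 80)

    def forward(self, mel):
        return self.linear(mel)


model = ToyModel()
model.eval()
# [-1, 80] marks the time axis as dynamic, like the pwg spec above.
static_model = jit.to_static(
    model, input_spec=[InputSpec([-1, 80], dtype=paddle.float32)])
paddle.jit.save(static_model, "exp/default/inference/toy")
# Reload and run the static graph, as the patched script does.
reloaded = paddle.jit.load("exp/default/inference/toy")
out = reloaded(paddle.randn([100, 80]))
```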
""" + if alpha != 1.0: assert alpha > 0 ds = paddle.round(ds.cast(dtype=paddle.float32) * alpha) diff --git a/parakeet/modules/fastspeech2_transformer/attention.py b/parakeet/modules/fastspeech2_transformer/attention.py index ae941a79a..0bac47426 100644 --- a/parakeet/modules/fastspeech2_transformer/attention.py +++ b/parakeet/modules/fastspeech2_transformer/attention.py @@ -106,13 +106,11 @@ class MultiHeadedAttention(nn.Layer): n_batch = value.shape[0] softmax = paddle.nn.Softmax(axis=-1) if mask is not None: - mask = mask.unsqueeze(1) mask = paddle.logical_not(mask) - min_value = float( - numpy.finfo( - paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min) - + # assume scores.dtype==paddle.float32, we only use "float32" here + dtype = str(scores.dtype).split(".")[-1] + min_value = numpy.finfo(dtype).min scores = masked_fill(scores, mask, min_value) # (batch, head, time1, time2) self.attn = softmax(scores) diff --git a/parakeet/modules/fastspeech2_transformer/embedding.py b/parakeet/modules/fastspeech2_transformer/embedding.py index 6c1c7245f..1dfd6dfdc 100644 --- a/parakeet/modules/fastspeech2_transformer/embedding.py +++ b/parakeet/modules/fastspeech2_transformer/embedding.py @@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer): Maximum input length. reverse : bool Whether to reverse the input position. + type : str + dtype of param """ - def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + def __init__(self, + d_model, + dropout_rate, + max_len=5000, + dtype="float32", + reverse=False): """Construct an PositionalEncoding object.""" super(PositionalEncoding, self).__init__() self.d_model = d_model @@ -41,20 +48,21 @@ class PositionalEncoding(nn.Layer): self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout_rate) self.pe = None - self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len))) + self.dtype = dtype + self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len))) def extend_pe(self, x): """Reset the positional encodings.""" - - pe = paddle.zeros([x.shape[1], self.d_model]) + x_shape = paddle.shape(x) + pe = paddle.zeros([x_shape[1], self.d_model]) if self.reverse: position = paddle.arange( - x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1) + x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1) else: position = paddle.arange( - 0, x.shape[1], dtype=paddle.float32).unsqueeze(1) + 0, x_shape[1], dtype=self.dtype).unsqueeze(1) div_term = paddle.exp( - paddle.arange(0, self.d_model, 2, dtype=paddle.float32) * + paddle.arange(0, self.d_model, 2, dtype=self.dtype) * -(math.log(10000.0) / self.d_model)) pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term) @@ -75,7 +83,8 @@ class PositionalEncoding(nn.Layer): Encoded tensor (batch, time, `*`). """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, :x.shape[1]] + T = paddle.shape(x)[1] + x = x * self.xscale + self.pe[:, :T] return self.dropout(x) @@ -92,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding): Dropout rate. max_len : int Maximum input length. 
diff --git a/parakeet/modules/fastspeech2_transformer/attention.py b/parakeet/modules/fastspeech2_transformer/attention.py
index ae941a79a..0bac47426 100644
--- a/parakeet/modules/fastspeech2_transformer/attention.py
+++ b/parakeet/modules/fastspeech2_transformer/attention.py
@@ -106,13 +106,11 @@ class MultiHeadedAttention(nn.Layer):
         n_batch = value.shape[0]
         softmax = paddle.nn.Softmax(axis=-1)
         if mask is not None:
-
             mask = mask.unsqueeze(1)
             mask = paddle.logical_not(mask)
-            min_value = float(
-                numpy.finfo(
-                    paddle.to_tensor(0, dtype=scores.dtype).numpy().dtype).min)
-
+            # assume scores.dtype == paddle.float32, we only use "float32" here
+            dtype = str(scores.dtype).split(".")[-1]
+            min_value = numpy.finfo(dtype).min
             scores = masked_fill(scores, mask, min_value)
             # (batch, head, time1, time2)
             self.attn = softmax(scores)
diff --git a/parakeet/modules/fastspeech2_transformer/embedding.py b/parakeet/modules/fastspeech2_transformer/embedding.py
index 6c1c7245f..1dfd6dfdc 100644
--- a/parakeet/modules/fastspeech2_transformer/embedding.py
+++ b/parakeet/modules/fastspeech2_transformer/embedding.py
@@ -31,9 +31,16 @@ class PositionalEncoding(nn.Layer):
         Maximum input length.
     reverse : bool
         Whether to reverse the input position.
+    dtype : str
+        dtype of the positional encoding parameters
     """
 
-    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+    def __init__(self,
+                 d_model,
+                 dropout_rate,
+                 max_len=5000,
+                 dtype="float32",
+                 reverse=False):
         """Construct an PositionalEncoding object."""
         super(PositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -41,20 +48,21 @@ class PositionalEncoding(nn.Layer):
         self.xscale = math.sqrt(self.d_model)
         self.dropout = nn.Dropout(p=dropout_rate)
         self.pe = None
-        self.extend_pe(paddle.expand(paddle.to_tensor(0.0), (1, max_len)))
+        self.dtype = dtype
+        self.extend_pe(paddle.expand(paddle.zeros([1]), (1, max_len)))
 
     def extend_pe(self, x):
         """Reset the positional encodings."""
-
-        pe = paddle.zeros([x.shape[1], self.d_model])
+        x_shape = paddle.shape(x)
+        pe = paddle.zeros([x_shape[1], self.d_model])
         if self.reverse:
             position = paddle.arange(
-                x.shape[1] - 1, -1, -1.0, dtype=paddle.float32).unsqueeze(1)
+                x_shape[1] - 1, -1, -1.0, dtype=self.dtype).unsqueeze(1)
         else:
             position = paddle.arange(
-                0, x.shape[1], dtype=paddle.float32).unsqueeze(1)
+                0, x_shape[1], dtype=self.dtype).unsqueeze(1)
         div_term = paddle.exp(
-            paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
+            paddle.arange(0, self.d_model, 2, dtype=self.dtype) *
             -(math.log(10000.0) / self.d_model))
         pe[:, 0::2] = paddle.sin(position * div_term)
         pe[:, 1::2] = paddle.cos(position * div_term)
@@ -75,7 +83,8 @@ class PositionalEncoding(nn.Layer):
             Encoded tensor (batch, time, `*`).
         """
         self.extend_pe(x)
-        x = x * self.xscale + self.pe[:, :x.shape[1]]
+        T = paddle.shape(x)[1]
+        x = x * self.xscale + self.pe[:, :T]
         return self.dropout(x)
 
@@ -92,21 +101,26 @@ class ScaledPositionalEncoding(PositionalEncoding):
         Dropout rate.
     max_len : int
         Maximum input length.
+    dtype : str
+        dtype of the positional encoding parameters
     """
 
-    def __init__(self, d_model, dropout_rate, max_len=5000):
+    def __init__(self, d_model, dropout_rate, max_len=5000, dtype="float32"):
         """Initialize class."""
         super().__init__(
-            d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
-        x = paddle.ones([1], dtype="float32")
+            d_model=d_model,
+            dropout_rate=dropout_rate,
+            max_len=max_len,
+            dtype=dtype)
+        x = paddle.ones([1], dtype=self.dtype)
         self.alpha = paddle.create_parameter(
             shape=x.shape,
-            dtype=str(x.numpy().dtype),
+            dtype=self.dtype,
            default_initializer=paddle.nn.initializer.Assign(x))
 
     def reset_parameters(self):
         """Reset parameters."""
-        self.alpha = paddle.to_tensor(1.0)
+        self.alpha = paddle.ones([1])
 
     def forward(self, x):
         """Add positional encoding.
@@ -115,12 +129,12 @@ class ScaledPositionalEncoding(PositionalEncoding):
         ----------
         x : paddle.Tensor
             Input tensor (batch, time, `*`).
-
         Returns
         ----------
         paddle.Tensor
             Encoded tensor (batch, time, `*`).
         """
         self.extend_pe(x)
-        x = x + self.alpha * self.pe[:, :x.shape[1]]
+        T = paddle.shape(x)[1]
+        x = x + self.alpha * self.pe[:, :T]
         return self.dropout(x)
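For reference, the table that `extend_pe` fills is the standard sinusoidal encoding, `pe[pos, 2i] = sin(pos / 10000^(2i/d_model))` and `pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))`; the dtype and shape changes above do not alter it. The same table recomputed in plain numpy, as a sketch with arbitrary small sizes:

```python
import math

import numpy as np

d_model, max_len = 8, 16
position = np.arange(max_len, dtype="float32")[:, None]  # (max_len, 1)
div_term = np.exp(
    np.arange(0, d_model, 2, dtype="float32") *
    -(math.log(10000.0) / d_model))                      # (d_model / 2,)
pe = np.zeros([max_len, d_model], dtype="float32")
pe[:, 0::2] = np.sin(position * div_term)  # even channels
pe[:, 1::2] = np.cos(position * div_term)  # odd channels
# PositionalEncoding.forward then adds pe[None, :T] to the scaled input.
```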
diff --git a/parakeet/modules/fastspeech2_transformer/encoder.py b/parakeet/modules/fastspeech2_transformer/encoder.py
index 630b50ff5..996e9dee0 100644
--- a/parakeet/modules/fastspeech2_transformer/encoder.py
+++ b/parakeet/modules/fastspeech2_transformer/encoder.py
@@ -185,6 +185,7 @@ class Encoder(nn.Layer):
         paddle.Tensor
             Mask tensor (#batch, time).
         """
+
         xs = self.embed(xs)
         xs, masks = self.encoders(xs, masks)
         if self.normalize_before:
diff --git a/parakeet/modules/layer_norm.py b/parakeet/modules/layer_norm.py
index 3bab823f2..a1c775fc8 100644
--- a/parakeet/modules/layer_norm.py
+++ b/parakeet/modules/layer_norm.py
@@ -44,6 +44,7 @@ class LayerNorm(paddle.nn.LayerNorm):
         paddle.Tensor
             Normalized tensor.
         """
+
         if self.dim == -1:
             return super(LayerNorm, self).forward(x)
         else:
@@ -54,9 +55,12 @@ class LayerNorm(paddle.nn.LayerNorm):
 
             orig_perm = list(range(len_dim))
             new_perm = orig_perm[:]
-            new_perm[self.dim], new_perm[len_dim -
-                                         1] = new_perm[len_dim -
-                                                       1], new_perm[self.dim]
+            # A Python-style tuple swap is not supported when converting
+            # dygraph to static graph, so swap the items explicitly:
+            # new_perm[self.dim], new_perm[len_dim - 1] = new_perm[len_dim - 1], new_perm[self.dim]
+            temp = new_perm[self.dim]
+            new_perm[self.dim] = new_perm[len_dim - 1]
+            new_perm[len_dim - 1] = temp
 
             return paddle.transpose(
                 super(LayerNorm, self).forward(paddle.transpose(x, new_perm)),
diff --git a/parakeet/modules/masked_fill.py b/parakeet/modules/masked_fill.py
index 34230f1c4..b32222547 100644
--- a/parakeet/modules/masked_fill.py
+++ b/parakeet/modules/masked_fill.py
@@ -25,12 +25,24 @@ def is_broadcastable(shp1, shp2):
     return True
 
 
+# assume that len(shp1) == len(shp2)
+def broadcast_shape(shp1, shp2):
+    result = []
+    for a, b in zip(shp1[::-1], shp2[::-1]):
+        result.append(max(a, b))
+    return result[::-1]
+
+
 def masked_fill(xs: paddle.Tensor,
                 mask: paddle.Tensor,
                 value: Union[float, int]):
-    assert is_broadcastable(xs.shape, mask.shape) is True
-    bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    # the following lines are commented out to allow converting dygraph to static graph
+    # assert is_broadcastable(xs.shape, mask.shape) is True
+    # bshape = paddle.broadcast_shape(xs.shape, mask.shape)
+    bshape = broadcast_shape(xs.shape, mask.shape)
+    mask.stop_gradient = True
     mask = mask.broadcast_to(bshape)
+
     trues = paddle.ones_like(xs) * value
     mask = mask.cast(dtype=paddle.bool)
     xs = paddle.where(mask, trues, xs)
diff --git a/parakeet/modules/nets_utils.py b/parakeet/modules/nets_utils.py
index 47eae65d6..0696335a5 100644
--- a/parakeet/modules/nets_utils.py
+++ b/parakeet/modules/nets_utils.py
@@ -56,7 +56,7 @@ def make_pad_mask(lengths, length_dim=-1):
 
     Parameters
     ----------
-    lengths : LongTensor or List
+    lengths : LongTensor
         Batch of lengths (B,).
 
     Returns
@@ -77,17 +77,11 @@ def make_pad_mask(lengths, length_dim=-1):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
 
-    if not isinstance(lengths, list):
-        lengths = lengths.tolist()
-    bs = int(len(lengths))
-
-    maxlen = int(max(lengths))
-
+    bs = paddle.shape(lengths)[0]
+    maxlen = lengths.max()
     seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
     seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
-
-    seq_length_expand = paddle.to_tensor(
-        lengths, dtype=seq_range_expand.dtype).unsqueeze(-1)
+    seq_length_expand = lengths.unsqueeze(-1)
     mask = seq_range_expand >= seq_length_expand
 
     return mask
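The tensor-only `make_pad_mask` now stays entirely inside the graph. A quick illustration of its output for a toy batch of lengths `[3, 1]`, run in dygraph and mirroring the new code path:

```python
import paddle

lengths = paddle.to_tensor([3, 1], dtype=paddle.int64)
bs = paddle.shape(lengths)[0]
maxlen = lengths.max()
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
mask = seq_range_expand >= lengths.unsqueeze(-1)
# mask:
# [[False, False, False],   # utterance 0: all 3 frames are real
#  [False, True,  True ]]   # utterance 1: frames 1 and 2 are padding
```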