using to_static

pull/782/head
Hui Zhang 4 years ago
parent 30a71de217
commit 27daa92a81

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Evaluation for U2 model.""" """Evaluation for U2 model."""
import cProfile import cProfile
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import

@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
best_hyps = best_hyps[:, 1:] best_hyps = best_hyps[:, 1:]
return best_hyps return best_hyps
@jit.export @jit.to_static
def subsampling_rate(self) -> int: def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the """ Export interface for c++ call, return subsampling_rate of the
model model
""" """
return self.encoder.embed.subsampling_rate return self.encoder.embed.subsampling_rate
@jit.export @jit.to_static
def right_context(self) -> int: def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model """ Export interface for c++ call, return right_context of the model
""" """
return self.encoder.embed.right_context return self.encoder.embed.right_context
@jit.export @jit.to_static
def sos_symbol(self) -> int: def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model """ Export interface for c++ call, return sos symbol id of the model
""" """
return self.sos return self.sos
@jit.export @jit.to_static
def eos_symbol(self) -> int: def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model """ Export interface for c++ call, return eos symbol id of the model
""" """
return self.eos return self.eos
@jit.export @jit.to_static
def forward_encoder_chunk( def forward_encoder_chunk(
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache, xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache) elayers_output_cache, conformer_cnn_cache)
@jit.export @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log """ Export interface for c++ call, apply linear transform and log
softmax before ctc softmax before ctc
@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
""" """
return self.ctc.log_softmax(xs) return self.ctc.log_softmax(xs)
@jit.export @jit.to_static
def forward_attention_decoder( def forward_attention_decoder(
self, self,
hyps: paddle.Tensor, hyps: paddle.Tensor,

Loading…
Cancel
Save