using to_static

pull/782/head
Hui Zhang 3 years ago
parent 30a71de217
commit 27daa92a81

@ -13,6 +13,7 @@
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import

@ -417,32 +417,32 @@ class U2STBaseModel(nn.Layer):
best_hyps = best_hyps[:, 1:]
return best_hyps
@jit.export
@jit.to_static
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
@jit.to_static
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
@jit.to_static
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
@jit.to_static
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@jit.export
@jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
@ -472,7 +472,7 @@ class U2STBaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
@jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
@ -483,7 +483,7 @@ class U2STBaseModel(nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.export
@jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,

Loading…
Cancel
Save