|
|
@ -21,8 +21,8 @@ from yacs.config import CfgNode
|
|
|
|
from deepspeech.modules.conv import ConvStack
|
|
|
|
from deepspeech.modules.conv import ConvStack
|
|
|
|
from deepspeech.modules.ctc import CTCDecoder
|
|
|
|
from deepspeech.modules.ctc import CTCDecoder
|
|
|
|
from deepspeech.modules.rnn import RNNStack
|
|
|
|
from deepspeech.modules.rnn import RNNStack
|
|
|
|
from deepspeech.utils import checkpoint
|
|
|
|
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
from deepspeech.utils import layer_tools
|
|
|
|
|
|
|
|
from deepspeech.utils.checkpoint import Checkpoint
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
|
logger = Log(__name__).getlog()
|
|
|
@ -222,7 +222,7 @@ class DeepSpeech2Model(nn.Layer):
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
rnn_size=config.model.rnn_layer_size,
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
use_gru=config.model.use_gru,
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
share_rnn_weights=config.model.share_rnn_weights)
|
|
|
|
infos = checkpoint.load_parameters(
|
|
|
|
infos = Checkpoint().load_parameters(
|
|
|
|
model, checkpoint_path=checkpoint_path)
|
|
|
|
model, checkpoint_path=checkpoint_path)
|
|
|
|
logger.info(f"checkpoint info: {infos}")
|
|
|
|
logger.info(f"checkpoint info: {infos}")
|
|
|
|
layer_tools.summary(model)
|
|
|
|
layer_tools.summary(model)
|
|
|
|