fix export, dataloader time log

pull/549/head
Hui Zhang 5 years ago
parent df5bc5e720
commit 29dbfa86ad

@ -305,11 +305,12 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
exit(-1)
def export(self):
self.infer_model.eval()
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, self.config, self.args.checkpoint_path)
infer_model.eval()
feat_dim = self.test_loader.dataset.feature_size
paddle.jit.save(
self.infer_model,
self.args.export_path,
static_model = paddle.jit.to_static(
infer_model,
input_spec=[
paddle.static.InputSpec(
shape=[None, feat_dim, None],
@ -317,6 +318,8 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
paddle.static.InputSpec(shape=[None],
dtype='int64'), # audio_length, [B]
])
logger.info(f"Export code: {static_model.forward.code}")
paddle.jit.save(static_model, self.args.export_path)
def run_export(self):
try:
@ -349,12 +352,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
rnn_size=config.model.rnn_layer_size,
use_gru=config.model.use_gru,
share_rnn_weights=config.model.share_rnn_weights)
infer_model = DeepSpeech2InferModel.from_pretrained(
self.test_loader.dataset, config, self.args.checkpoint_path)
self.model = model
self.infer_model = infer_model
self.logger.info("Setup model!")
def setup_dataloader(self):

@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positonal Encoding Module."""
import math

@ -70,7 +70,7 @@ class CTCLoss(nn.Layer):
Returns:
[paddle.Tensor]: scalar. If reduction is 'none', then (N), where N = \text{batch size}.
"""
# warp-ctc do softmax on activations
# warp-ctc need logits, and do softmax on logits by itself
# warp-ctc need activation with shape [T, B, V + 1]
# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose([1, 0, 2])

@ -167,9 +167,17 @@ class Trainer():
self.new_epoch()
while self.epoch <= self.config.training.n_epoch:
try:
data_start_time = time.time()
for batch in self.train_loader:
dataload_time = time.time() - data_start_time
msg = "Train: Rank: {}, ".format(dist.get_rank())
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
msg += "dataloader time: {:>.3f}s, ".format(dataload_time)
self.logger.info(msg)
self.iteration += 1
self.train_batch(batch)
data_start_time = time.time()
except Exception as e:
self.logger.error(e)
pass

@ -0,0 +1,20 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: export ckpt_path jit_model_path"
exit -1
fi
python3 -u ${BIN_DIR}/export.py \
--config conf/deepspeech2.yaml \
--checkpoint_path ${1} \
--export_path ${2}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
exit 0

@ -13,7 +13,6 @@ python3 -u ${BIN_DIR}/test.py \
--config conf/deepspeech2.yaml \
--output ckpt
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1

Loading…
Cancel
Save