From 02083cdbd6a25156e9ec97a8e62f00d3a1ec12e2 Mon Sep 17 00:00:00 2001 From: huangyuxin Date: Mon, 1 Nov 2021 08:37:25 +0000 Subject: [PATCH] fix the bug of 'dev/null' and the test_export --- deepspeech/exps/deepspeech2/model.py | 10 +++++----- deepspeech/training/trainer.py | 15 ++++++++++----- examples/aishell/s0/local/test.sh | 2 +- examples/aishell/s0/local/test_export.sh | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index 710630a78..56743629b 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -21,11 +21,6 @@ from typing import Optional import jsonlines import numpy as np import paddle -from paddle import distributed as dist -from paddle import inference -from paddle.io import DataLoader -from yacs.config import CfgNode - from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer from deepspeech.io.collator import SpeechCollator from deepspeech.io.dataset import ManifestDataset @@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools from deepspeech.utils.log import Autolog from deepspeech.utils.log import Log from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from paddle import inference +from paddle.io import DataLoader +from yacs.config import CfgNode logger = Log(__name__).getlog() @@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): class DeepSpeech2ExportTester(DeepSpeech2Tester): def __init__(self, config, args): super().__init__(config, args) + self.apply_static = True def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg): if self.args.model_type == "online": diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 2da838047..ddde1e885 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,9 +18,6 @@ from contextlib import contextmanager from pathlib import Path import paddle -from paddle import distributed as dist -from tensorboardX import SummaryWriter - from deepspeech.training.reporter import ObsScope from deepspeech.training.reporter import report from deepspeech.training.timer import Timer @@ -31,6 +28,8 @@ from deepspeech.utils.log import Log from deepspeech.utils.utility import all_version from deepspeech.utils.utility import seed_all from deepspeech.utils.utility import UpdateConfig +from paddle import distributed as dist +from tensorboardX import SummaryWriter __all__ = ["Trainer"] @@ -348,8 +347,12 @@ class Trainer(): try: with Timer("Test/Decode Done: {}"): with self.eval(): - self.restore() - self.test() + if hasattr(self, + "apply_static") and self.apply_static is True: + self.test() + else: + self.restore() + self.test() except KeyboardInterrupt: exit(-1) @@ -381,6 +384,8 @@ class Trainer(): elif self.args.checkpoint_path: output_dir = Path( self.args.checkpoint_path).expanduser().parent.parent + elif self.args.export_path: + output_dir = Path(self.args.export_path).expanduser().parent.parent self.output_dir = output_dir self.output_dir.mkdir(parents=True, exist_ok=True) diff --git a/examples/aishell/s0/local/test.sh b/examples/aishell/s0/local/test.sh index 64d725030..d539ac494 100755 --- a/examples/aishell/s0/local/test.sh +++ b/examples/aishell/s0/local/test.sh @@ -13,7 +13,7 @@ ckpt_prefix=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh > dev/null 2>&1 +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi diff --git a/examples/aishell/s0/local/test_export.sh b/examples/aishell/s0/local/test_export.sh index 71469753d..f0a30ce56 100755 --- a/examples/aishell/s0/local/test_export.sh +++ b/examples/aishell/s0/local/test_export.sh @@ -13,7 +13,7 @@ jit_model_export_path=$2 model_type=$3 # download language model -bash local/download_lm_ch.sh > dev/null 2>&1 +bash local/download_lm_ch.sh > /dev/null 2>&1 if [ $? -ne 0 ]; then exit 1 fi