Merge pull request #953 from Jackwaterveg/fix_bug

[Bug fix] fix the bug of 'dev/null' and the test_export
pull/955/head
Hui Zhang 3 years ago committed by GitHub
commit 1372a08813
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,11 +21,6 @@ from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
@ -44,6 +39,10 @@ from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
logger = Log(__name__).getlog()
@ -412,6 +411,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
class DeepSpeech2ExportTester(DeepSpeech2Tester):
def __init__(self, config, args):
super().__init__(config, args)
self.apply_static = True
def compute_result_transcripts(self, audio, audio_len, vocab_list, cfg):
if self.args.model_type == "online":

@ -18,9 +18,6 @@ from contextlib import contextmanager
from pathlib import Path
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
@ -31,6 +28,8 @@ from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from tensorboardX import SummaryWriter
__all__ = ["Trainer"]
@ -348,8 +347,12 @@ class Trainer():
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.restore()
self.test()
if hasattr(self,
"apply_static") and self.apply_static is True:
self.test()
else:
self.restore()
self.test()
except KeyboardInterrupt:
exit(-1)
@ -381,6 +384,8 @@ class Trainer():
elif self.args.checkpoint_path:
output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent
elif self.args.export_path:
output_dir = Path(self.args.export_path).expanduser().parent.parent
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)

@ -13,7 +13,7 @@ ckpt_prefix=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh > dev/null 2>&1
bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then
exit 1
fi

@ -13,7 +13,7 @@ jit_model_export_path=$2
model_type=$3
# download language model
bash local/download_lm_ch.sh > dev/null 2>&1
bash local/download_lm_ch.sh > /dev/null 2>&1
if [ $? -ne 0 ]; then
exit 1
fi

Loading…
Cancel
Save