Merge pull request #955 from Jackwaterveg/fix

fix the run_test in test_export
pull/961/head
Hui Zhang 3 years ago committed by GitHub
commit 2fa681237f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,6 +21,11 @@ from typing import Optional
import jsonlines
import numpy as np
import paddle
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
from deepspeech.io.collator import SpeechCollator
from deepspeech.io.dataset import ManifestDataset
@ -32,6 +37,7 @@ from deepspeech.models.ds2_online import DeepSpeech2InferModelOnline
from deepspeech.models.ds2_online import DeepSpeech2ModelOnline
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.trainer import Trainer
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
@ -39,10 +45,6 @@ from deepspeech.utils import mp_tools
from deepspeech.utils.log import Autolog
from deepspeech.utils.log import Log
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from paddle import inference
from paddle.io import DataLoader
from yacs.config import CfgNode
logger = Log(__name__).getlog()
@ -441,6 +443,15 @@ class DeepSpeech2ExportTester(DeepSpeech2Tester):
return result_transcripts
def run_test(self):
"""Do Test/Decode"""
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
self.test()
except KeyboardInterrupt:
exit(-1)
def static_forward_online(self, audio, audio_len,
decoder_chunk_size: int=1):
"""

@ -18,6 +18,9 @@ from contextlib import contextmanager
from pathlib import Path
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.training.reporter import ObsScope
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
@ -28,8 +31,6 @@ from deepspeech.utils.log import Log
from deepspeech.utils.utility import all_version
from deepspeech.utils.utility import seed_all
from deepspeech.utils.utility import UpdateConfig
from paddle import distributed as dist
from tensorboardX import SummaryWriter
__all__ = ["Trainer"]
@ -347,12 +348,8 @@ class Trainer():
try:
with Timer("Test/Decode Done: {}"):
with self.eval():
if hasattr(self,
"apply_static") and self.apply_static is True:
self.test()
else:
self.restore()
self.test()
self.restore()
self.test()
except KeyboardInterrupt:
exit(-1)

Loading…
Cancel
Save