format code

pull/1273/head
Hui Zhang 3 years ago
parent 6f651d762e
commit 3a2db414e6

@ -19,8 +19,8 @@ import paddle
from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -17,8 +17,8 @@ import functools
import numpy as np
import paddle
from paddle.io import DataLoader
from yacs.config import CfgNode
from paddlespeech.s2t.io.collator import SpeechCollator
from paddlespeech.s2t.io.dataset import ManifestDataset
from paddlespeech.s2t.models.ds2 import DeepSpeech2Model

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for DeepSpeech2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -13,6 +13,7 @@
# limitations under the License.
"""Evaluation for DeepSpeech2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -13,8 +13,8 @@
# limitations under the License.
"""Trainer for DeepSpeech2 model."""
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -16,8 +16,8 @@ import cProfile
import os
from paddle import distributed as dist
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -42,6 +42,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
logger = Log(__name__).getlog()
class U2Trainer(Trainer):
def __init__(self, config, args):
super().__init__(config, args)

@ -13,6 +13,7 @@
# limitations under the License.
"""Export for U2 model."""
from yacs.config import CfgNode
from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
from paddlespeech.s2t.training.cli import default_argument_parser
from paddlespeech.s2t.utils.utility import print_arguments

@ -208,8 +208,7 @@ class U2STTrainer(Trainer):
k.split(',')) == 2 else ""
msg += ","
msg = msg[:-1] # remove the last ","
if (batch_index + 1
) % self.config.log_interval == 0:
if (batch_index + 1) % self.config.log_interval == 0:
logger.info(msg)
except Exception as e:
logger.error(e)
@ -260,7 +259,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
@ -281,7 +281,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
@ -290,7 +291,8 @@ class U2STTrainer(Trainer):
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1)
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = BatchDataLoader(
json_file=config.test_manifest,
train_mode=False,
@ -305,7 +307,8 @@ class U2STTrainer(Trainer):
batch_frames_in=0,
batch_frames_out=0,
batch_frames_inout=0,
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
preprocess_conf=config.
preprocess_config, # aug will be off when train_mode=False
n_iter_processes=config.num_workers,
subsampling_factor=1,
num_encs=1,

@ -119,6 +119,7 @@ class DeepSpeech2Model(nn.Layer):
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
def __init__(self,
feat_size,
dict_size,

@ -243,6 +243,7 @@ class DeepSpeech2ModelOnline(nn.Layer):
before softmax) and a ctc cost layer.
:rtype: tuple of LayerOutput
"""
def __init__(
self,
feat_size,

@ -59,6 +59,7 @@ logger = Log(__name__).getlog()
class U2BaseModel(ASRInterface, nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,

@ -51,6 +51,7 @@ logger = Log(__name__).getlog()
class U2STBaseModel(nn.Layer):
"""CTC-Attention hybrid Encoder-Decoder model"""
def __init__(self,
vocab_size: int,
encoder: TransformerEncoder,

@ -39,7 +39,6 @@ except ImportError:
except Exception as e:
logger.info("paddlespeech_ctcdecoders not installed!")
__all__ = ['CTCDecoder']

Loading…
Cancel
Save