Merge pull request #789 from PaddlePaddle/seed

seed all with log; and format
pull/792/head
Jackwaterveg 3 years ago committed by GitHub
commit 794294e9cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,7 +64,7 @@ def default_argument_parser():
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None, parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. The default value is None") help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapd: enable # yapd: enable
return parser return parser

@ -1,8 +1,21 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable from typing import Callable
from .extension import Extension from .extension import Extension
def make_extension(trigger: Callable=None, def make_extension(trigger: Callable=None,
default_name: str=None, default_name: str=None,
priority: int=None, priority: int=None,

@ -1,10 +1,23 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict from typing import Dict
import extension
import paddle import paddle
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
import extension
from ..reporter import DictSummary from ..reporter import DictSummary
from ..reporter import report from ..reporter import report
from ..reporter import scope from ..reporter import scope

@ -1,5 +1,16 @@
from typing import Callable # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300 PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200 PRIORITY_EDITOR = 200
PRIORITY_READER = 100 PRIORITY_READER = 100

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@ -7,11 +20,10 @@ from typing import List
import jsonlines import jsonlines
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
@ -75,7 +87,7 @@ class Snapshot(extension.Extension):
"""Saving new snapshot and remove the oldest snapshot if needed.""" """Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] is 'epoch' else iteration num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz" path = self.checkpoint_dir / f"{num}.pdz"
# add the new one # add the new one

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer from deepspeech.training.updaters.trainer import Trainer

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib import contextlib
import math import math
from collections import defaultdict from collections import defaultdict

@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import time import time
from pathlib import Path from pathlib import Path
import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
__all__ = ["Trainer"] __all__ = ["Trainer"]
@ -95,13 +94,10 @@ class Trainer():
self.checkpoint_dir = None self.checkpoint_dir = None
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
if args.seed is not None:
self.set_seed(args.seed)
def set_seed(self, seed): if args.seed:
np.random.seed(seed) seed_all(args.seed)
random.seed(seed) logger.info(f"Set seed {args.seed}")
paddle.seed(seed)
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
@ -182,7 +178,9 @@ class Trainer():
""" """
self.epoch += 1 self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"): if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch) batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def train(self): def train(self):
"""The training process control by epoch.""" """The training process control by epoch."""

@ -1,8 +1,23 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer): def never_fail_trigger(trainer):
return False return False
def get_trigger(trigger): def get_trigger(trigger):
if trigger is None: if trigger is None:
return never_fail_trigger return never_fail_trigger

@ -1,3 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger(): class IntervalTrigger():
"""A Predicate to do something every N cycle.""" """A Predicate to do something every N cycle."""

@ -1,3 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger(): class LimitTrigger():
"""A Predicate to decide whether to stop.""" """A Predicate to decide whether to stop."""

@ -1,3 +1,18 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger(): class TimeTrigger():
"""Trigger based on a fixed time interval. """Trigger based on a fixed time interval.
This trigger accepts iterations with a given interval time. This trigger accepts iterations with a given interval time.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
@ -11,13 +24,13 @@ from timer import timer
from deepspeech.training.reporter import report from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["StandardUpdater"] __all__ = ["StandardUpdater"]
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase): class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but """An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need. you can subclass it to fit your need.
@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase):
"""Start a new epoch.""" """Start a new epoch."""
# NOTE: all batch sampler for distributed training should # NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method # subclass DistributedBatchSampler and implement `set_epoch` method
if hasattr(self.dataloader, "batch_sampler") if hasattr(self.dataloader, "batch_sampler"):
batch_sampler = self.dataloader.batch_sampler batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler): if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch) batch_sampler.set_epoch(self.state.epoch)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys import sys
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
import paddle import paddle
from deepspeech.utils.log import Log from deepspeech.utils.log import Log

@ -15,9 +15,19 @@
import distutils.util import distutils.util
import math import math
import os import os
import random
from typing import List from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"] import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
def seed_all(seed: int=210329):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None): def print_arguments(args, info=None):

@ -6,8 +6,6 @@
| data/manifest.dev | 1.645 ~ 12.533 | | data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 | | data/manifest.test | 1.859125 ~ 14.6999375 |
`jq '.feat_shape[0]' data/manifest.train | sort -un`
## Deepspeech2 ## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER | | Model | Params | Release | Config | Test set | Loss | CER |

@ -1,5 +1,6 @@
coverage coverage
gpustat gpustat
jsonlines
kaldiio kaldiio
Pillow Pillow
pre-commit pre-commit
@ -15,4 +16,3 @@ tensorboardX
textgrid textgrid
typeguard typeguard
yacs yacs
jsonlines
Loading…
Cancel
Save