Merge pull request #789 from PaddlePaddle/seed

seed all with log; and format
pull/792/head
Jackwaterveg 3 years ago committed by GitHub
commit 794294e9cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,7 +64,7 @@ def default_argument_parser():
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--seed", type=int, default=None,
help="seed to use for paddle, np and random. The default value is None")
help="seed to use for paddle, np and random. None or 0 for random, else set seed.")
# yapd: enable
return parser

@ -1,8 +1,21 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from .extension import Extension
def make_extension(trigger: Callable=None,
default_name: str=None,
priority: int=None,

@ -1,10 +1,23 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import extension
import paddle
from paddle.io import DataLoader
from paddle.nn import Layer
import extension
from ..reporter import DictSummary
from ..reporter import report
from ..reporter import scope

@ -1,5 +1,16 @@
from typing import Callable
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from datetime import datetime
from pathlib import Path
@ -7,11 +20,10 @@ from typing import List
import jsonlines
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.training.extensions import extension
from deepspeech.utils.mp_tools import rank_zero_only
from deepspeech.training.updaters.trainer import Trainer
from deepspeech.utils.log import Log
from deepspeech.utils.mp_tools import rank_zero_only
logger = Log(__name__).getlog()
@ -75,7 +87,7 @@ class Snapshot(extension.Extension):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
epoch = trainer.updater.state.epoch
num = epoch if self.trigger[1] is 'epoch' else iteration
num = epoch if self.trigger[1] == 'epoch' else iteration
path = self.checkpoint_dir / f"{num}.pdz"
# add the new one

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.extensions import extension
from deepspeech.training.updaters.trainer import Trainer

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from collections import defaultdict

@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import time
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
@ -23,6 +21,7 @@ from tensorboardX import SummaryWriter
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
from deepspeech.utils.utility import seed_all
__all__ = ["Trainer"]
@ -95,13 +94,10 @@ class Trainer():
self.checkpoint_dir = None
self.iteration = 0
self.epoch = 0
if args.seed is not None:
self.set_seed(args.seed)
def set_seed(self, seed):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
if args.seed:
seed_all(args.seed)
logger.info(f"Set seed {args.seed}")
def setup(self):
"""Setup the experiment.
@ -182,7 +178,9 @@ class Trainer():
"""
self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def train(self):
"""The training process control by epoch."""

@ -1,8 +1,23 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
return False
def get_trigger(trigger):
if trigger is None:
return never_fail_trigger

@ -1,3 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger():
"""A Predicate to do something every N cycle."""

@ -1,3 +1,17 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger():
"""A Predicate to decide whether to stop."""

@ -1,3 +1,18 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger():
"""Trigger based on a fixed time interval.
This trigger accepts iterations with a given interval time.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Optional
@ -11,13 +24,13 @@ from timer import timer
from deepspeech.training.reporter import report
from deepspeech.training.updaters.updater import UpdaterBase
from deepspeech.training.updaters.updater import UpdaterState
from deepspeech.utils.log import Log
__all__ = ["StandardUpdater"]
logger = Log(__name__).getlog()
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
@ -142,7 +155,7 @@ class StandardUpdater(UpdaterBase):
"""Start a new epoch."""
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
if hasattr(self.dataloader, "batch_sampler")
if hasattr(self.dataloader, "batch_sampler"):
batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch)

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from collections import OrderedDict

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import paddle
from deepspeech.utils.log import Log

@ -15,9 +15,19 @@
import distutils.util
import math
import os
import random
from typing import List
__all__ = ['print_arguments', 'add_arguments', "log_add"]
import numpy as np
import paddle
__all__ = ["seed_all", 'print_arguments', 'add_arguments', "log_add"]
def seed_all(seed: int=210329):
np.random.seed(seed)
random.seed(seed)
paddle.seed(seed)
def print_arguments(args, info=None):

@ -6,8 +6,6 @@
| data/manifest.dev | 1.645 ~ 12.533 |
| data/manifest.test | 1.859125 ~ 14.6999375 |
`jq '.feat_shape[0]' data/manifest.train | sort -un`
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | CER |

@ -1,5 +1,6 @@
coverage
gpustat
jsonlines
kaldiio
Pillow
pre-commit
@ -15,4 +16,3 @@ tensorboardX
textgrid
typeguard
yacs
jsonlines
Loading…
Cancel
Save