|
|
@ -19,9 +19,8 @@ from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
from paddle import distributed as dist
|
|
|
|
from paddle import distributed as dist
|
|
|
|
import pdb
|
|
|
|
|
|
|
|
pdb.set_trace()
|
|
|
|
|
|
|
|
dist.init_parallel_env()
|
|
|
|
dist.init_parallel_env()
|
|
|
|
|
|
|
|
|
|
|
|
from visualdl import LogWriter
|
|
|
|
from visualdl import LogWriter
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.training.reporter import ObsScope
|
|
|
|
from paddlespeech.s2t.training.reporter import ObsScope
|
|
|
@ -125,9 +124,6 @@ class Trainer():
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
raise Exception("invalid device")
|
|
|
|
raise Exception("invalid device")
|
|
|
|
|
|
|
|
|
|
|
|
if self.parallel:
|
|
|
|
|
|
|
|
self.init_parallel()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.checkpoint = Checkpoint(
|
|
|
|
self.checkpoint = Checkpoint(
|
|
|
|
kbest_n=self.config.checkpoint.kbest_n,
|
|
|
|
kbest_n=self.config.checkpoint.kbest_n,
|
|
|
|
latest_n=self.config.checkpoint.latest_n)
|
|
|
|
latest_n=self.config.checkpoint.latest_n)
|
|
|
@ -176,11 +172,6 @@ class Trainer():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
return self.args.ngpu > 1
|
|
|
|
return self.args.ngpu > 1
|
|
|
|
|
|
|
|
|
|
|
|
def init_parallel(self):
|
|
|
|
|
|
|
|
"""Init environment for multiprocess training.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
# dist.init_parallel_env()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
@mp_tools.rank_zero_only
|
|
|
|
def save(self, tag=None, infos: dict=None):
|
|
|
|
def save(self, tag=None, infos: dict=None):
|
|
|
|
"""Save checkpoint (model parameters and optimizer states).
|
|
|
|
"""Save checkpoint (model parameters and optimizer states).
|
|
|
|