|
|
@ -11,16 +11,21 @@
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
from paddle.io import DataLoader
|
|
|
|
|
|
|
|
from paddle.io import DistributedBatchSampler
|
|
|
|
|
|
|
|
|
|
|
|
from paddleaudio.datasets.voxceleb import VoxCeleb1
|
|
|
|
from paddleaudio.datasets.voxceleb import VoxCeleb1
|
|
|
|
|
|
|
|
from paddlespeech.vector.datasets.batch import waveform_collate_fn
|
|
|
|
|
|
|
|
from paddlespeech.vector.layers.loss import AdditiveAngularMargin
|
|
|
|
|
|
|
|
from paddlespeech.vector.layers.loss import LogSoftmaxWrapper
|
|
|
|
from paddlespeech.vector.layers.lr import CyclicLRScheduler
|
|
|
|
from paddlespeech.vector.layers.lr import CyclicLRScheduler
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
|
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
|
|
|
|
from paddlespeech.vector.training.sid_model import SpeakerIdetification
|
|
|
|
from paddlespeech.vector.training.sid_model import SpeakerIdetification
|
|
|
|
from paddlespeech.vector.layers.loss import AdditiveAngularMargin, LogSoftmaxWrapper
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(args):
|
|
|
|
def main(args):
|
|
|
|
# stage0: set the training device, cpu or gpu
|
|
|
|
# stage0: set the training device, cpu or gpu
|
|
|
@ -61,7 +66,6 @@ def main(args):
|
|
|
|
criterion = LogSoftmaxWrapper(
|
|
|
|
criterion = LogSoftmaxWrapper(
|
|
|
|
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
|
|
|
|
loss_fn=AdditiveAngularMargin(margin=0.2, scale=30))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# stage7: confirm training start epoch
|
|
|
|
# stage7: confirm training start epoch
|
|
|
|
# if pre-trained model exists, start epoch confirmed by the pre-trained model
|
|
|
|
# if pre-trained model exists, start epoch confirmed by the pre-trained model
|
|
|
|
start_epoch = 0
|
|
|
|
start_epoch = 0
|
|
|
@ -90,6 +94,18 @@ def main(args):
|
|
|
|
except ValueError:
|
|
|
|
except ValueError:
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# stage8: we build the batch sampler for paddle.DataLoader
|
|
|
|
|
|
|
|
train_sampler = DistributedBatchSampler(
|
|
|
|
|
|
|
|
train_ds, batch_size=args.batch_size, shuffle=True, drop_last=False)
|
|
|
|
|
|
|
|
train_loader = DataLoader(
|
|
|
|
|
|
|
|
train_ds,
|
|
|
|
|
|
|
|
batch_sampler=train_sampler,
|
|
|
|
|
|
|
|
num_workers=args.num_workers,
|
|
|
|
|
|
|
|
collate_fn=waveform_collate_fn,
|
|
|
|
|
|
|
|
return_list=True,
|
|
|
|
|
|
|
|
use_buffer_reader=True, )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
if __name__ == "__main__":
|
|
|
|
# yapf: disable
|
|
|
|
# yapf: disable
|
|
|
|
parser = argparse.ArgumentParser(__doc__)
|
|
|
|
parser = argparse.ArgumentParser(__doc__)
|
|
|
@ -109,6 +125,13 @@ if __name__ == "__main__":
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
default=None,
|
|
|
|
help="Directory to load model checkpoint to contiune trainning.")
|
|
|
|
help="Directory to load model checkpoint to contiune trainning.")
|
|
|
|
|
|
|
|
parser.add_argument("--batch_size",
|
|
|
|
|
|
|
|
type=int, default=64,
|
|
|
|
|
|
|
|
help="Total examples' number in batch for training.")
|
|
|
|
|
|
|
|
parser.add_argument("--num_workers",
|
|
|
|
|
|
|
|
type=int,
|
|
|
|
|
|
|
|
default=0,
|
|
|
|
|
|
|
|
help="Number of workers in dataloader.")
|
|
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
# yapf: enable
|
|
|
|
# yapf: enable
|
|
|
|