PaddleSpeech/examples/audio/codec/dac/train.py

# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script for DAC model.
This script demonstrates how to train the DAC model with distributed training support.
"""
import argparse
import os
from pathlib import Path
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader, BatchSampler, DistributedBatchSampler
import yaml
from paddlespeech.audio.codec.dac.model import DACModel
from paddlespeech.audio.codec.dac.trainer import DACTrainer
from paddlespeech.audio.codec.dac.processor import DACProcessor
# TODO: Import dataset classes once implemented


def main(args):
    """Main training function.

    Args:
        args: Command line arguments.
    """
    # Setup distributed training environment
    if args.ngpus > 1:
        dist.init_parallel_env()

    # Load configuration
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Setup model
    model = DACModel(**config['model'])
    if args.ngpus > 1:
        model = paddle.DataParallel(model)

    # Setup optimizer
    optimizer = paddle.optimizer.Adam(
        parameters=model.parameters(),
        learning_rate=config['training']['lr'],
        weight_decay=config['training']['weight_decay'])

    # TODO: Setup dataset and dataloader
    # This is a placeholder for the dataset setup
    # train_dataset = ...
    # valid_dataset = ...
    # batch_sampler = BatchSampler(...)
    # if args.ngpus > 1:
    #     batch_sampler = DistributedBatchSampler(...)
    # train_dataloader = DataLoader(...)
    # valid_dataloader = DataLoader(...)
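    # One possible way to fill in the placeholder above. This is an
    # illustrative sketch only: `DACDataset` and the `config['data']` /
    # `config['training']` keys used here are assumptions, not part of this
    # repository. The sampler/dataloader wiring itself follows the standard
    # paddle.io API (BatchSampler, DistributedBatchSampler, DataLoader).
    #
    # processor = DACProcessor(**config['data'])
    # train_dataset = DACDataset(config['data']['train_manifest'], processor)
    # if args.ngpus > 1:
    #     # Shard batches across ranks so each GPU sees a distinct subset.
    #     batch_sampler = DistributedBatchSampler(
    #         train_dataset,
    #         batch_size=config['training']['batch_size'],
    #         shuffle=True,
    #         drop_last=True)
    # else:
    #     batch_sampler = BatchSampler(
    #         train_dataset,
    #         batch_size=config['training']['batch_size'],
    #         shuffle=True,
    #         drop_last=True)
    # train_dataloader = DataLoader(
    #     train_dataset,
    #     batch_sampler=batch_sampler,
    #     num_workers=config['training'].get('num_workers', 4))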
    # Setup trainer
    trainer = DACTrainer(
        model=model,
        optimizer=optimizer,
        dataloader=None,  # TODO: Replace with actual train_dataloader
        output_dir=output_dir,
        config=config,
        max_epoch=args.max_epoch)

    # Run training
    trainer.run()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train DAC model")
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to configuration file")
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Directory to save model checkpoints and logs")
    parser.add_argument(
        "--ngpus",
        type=int,
        default=1,
        help="Number of GPUs for distributed training")
    parser.add_argument(
        "--max-epoch",
        type=int,
        default=200,
        help="Maximum number of training epochs")

    args = parser.parse_args()
    main(args)