pull/2324/head
tianhao zhang 3 years ago
parent 80fc0ef71a
commit c314f8b769

@@ -114,6 +114,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
paddle.Tensor.new_full = new_full
paddle.static.Variable.new_full = new_full
def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
return xs

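Paddle tensors are always stored densely, which is why the contiguous shim above can simply return its input. A minimal sketch of how such a PyTorch-compatibility patch is attached and exercised, mirroring the new_full patch in the same file (the demo tensor is illustrative):

import paddle

def contiguous(xs: paddle.Tensor) -> paddle.Tensor:
    # Paddle stores tensors densely, so this PyTorch-compatibility
    # shim can be a no-op that returns its input unchanged.
    return xs

if not hasattr(paddle.Tensor, 'contiguous'):
    paddle.Tensor.contiguous = contiguous

x = paddle.randn([2, 3])
assert x.contiguous() is x  # no copy is made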
@@ -26,8 +26,8 @@ from paddle import distributed as dist
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
from paddlespeech.s2t.io.dataloader import BatchDataLoader
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.io.dataloader import DataLoaderFactory
from paddlespeech.s2t.io.dataloader import StreamDataLoader
from paddlespeech.s2t.models.u2 import U2Model
from paddlespeech.s2t.training.optimizer import OptimizerFactory
from paddlespeech.s2t.training.reporter import ObsScope
@@ -109,7 +109,8 @@ class U2Trainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -136,7 +137,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -157,7 +159,8 @@ class U2Trainer(Trainer):
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -225,14 +228,18 @@ class U2Trainer(Trainer):
config = self.config.clone()
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
decode_batch_size = config.get('decode', dict()).get(
'decode_batch_size', 1)
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):

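Both trainers route every loader through a single factory entry point keyed by mode. A toy sketch of that dispatch pattern using only paddle.io; ToyLoaderFactory and its config keys are illustrative, not PaddleSpeech's API:

from paddle.io import DataLoader, Dataset

class ToyLoaderFactory:
    @staticmethod
    def get_dataloader(mode: str, config: dict, dataset: Dataset) -> DataLoader:
        # Mode-specific overrides first, one constructor call at the end.
        if mode == 'train':
            overrides = {'shuffle': True,
                         'batch_size': config.get('batch_size', 16)}
        elif mode in ('valid', 'test', 'align'):
            # Evaluation modes keep sample order and use a decode-style batch size.
            overrides = {'shuffle': False,
                         'batch_size': config.get('decode_batch_size', 1)}
        else:
            raise KeyError("invalid mode type, please input one of "
                           "'train', 'valid', 'test', 'align'")
        return DataLoader(dataset, **overrides)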
@@ -105,7 +105,8 @@ class U2Trainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -133,7 +134,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -153,7 +155,8 @@ class U2Trainer(Trainer):
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -165,8 +168,8 @@ class U2Trainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch : {}/{}, ".format(batch_index + 1,
len(self.train_loader))
msg += "batch : {}/{}, ".format(
batch_index + 1, len(self.train_loader))
msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
msg += "data time: {:>.3f}s, ".format(dataload_time)
self.train_batch(batch_index, batch, msg)
@@ -204,21 +207,24 @@ class U2Trainer(Trainer):
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
config = self.config.clone()
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
config = self.config.clone()
config['preprocess_config'] = None
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
config = self.config.clone()
config['preprocess_config'] = None
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.args)
config = self.config.clone()
config['preprocess_config'] = None
self.align_loader = DataLoaderFactory.get_dataloader('align', config, self.args)
self.align_loader = DataLoaderFactory.get_dataloader(
'align', config, self.args)
logger.info("Setup test/align Dataloader!")
def setup_model(self):
config = self.config

@@ -121,7 +121,8 @@ class U2STTrainer(Trainer):
def valid(self):
self.model.eval()
if not self.use_streamdata:
logger.info(f"Valid Total Examples: {len(self.valid_loader.dataset)}")
logger.info(
f"Valid Total Examples: {len(self.valid_loader.dataset)}")
valid_losses = defaultdict(list)
num_seen_utts = 1
total_loss = 0.0
@@ -155,7 +156,8 @@ class U2STTrainer(Trainer):
msg += "epoch: {}, ".format(self.epoch)
msg += "step: {}, ".format(self.iteration)
if not self.use_streamdata:
msg += "batch: {}/{}, ".format(i + 1, len(self.valid_loader))
msg += "batch: {}/{}, ".format(i + 1,
len(self.valid_loader))
msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in valid_dump.items())
logger.info(msg)
@@ -175,7 +177,8 @@ class U2STTrainer(Trainer):
self.before_train()
if not self.use_streamdata:
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
while self.epoch < self.config.n_epoch:
with Timer("Epoch-Train Time Cost: {}"):
self.model.train()
@@ -248,14 +251,16 @@ class U2STTrainer(Trainer):
config['load_transcript'] = load_transcript
self.use_streamdata = config.get("use_stream_data", False)
if self.train:
self.train_loader = DataLoaderFactory.get_dataloader('train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader('valid', config, self.args)
self.train_loader = DataLoaderFactory.get_dataloader(
'train', config, self.args)
self.valid_loader = DataLoaderFactory.get_dataloader(
'valid', config, self.args)
logger.info("Setup train/valid Dataloader!")
else:
self.test_loader = DataLoaderFactory.get_dataloader('test', config, self.args)
self.test_loader = DataLoaderFactory.get_dataloader('test', config,
self.args)
logger.info("Setup test Dataloader!")
def setup_model(self):
config = self.config
model_conf = config

@@ -22,17 +22,16 @@ import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode
import paddlespeech.audio.streamdata as streamdata
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from paddlespeech.s2t.io.batchfy import make_batchset
from paddlespeech.s2t.io.converter import CustomConverter
from paddlespeech.s2t.io.dataset import TransformDataset
from paddlespeech.s2t.io.reader import LoadInputsAndTargets
from paddlespeech.s2t.utils.log import Log
import paddlespeech.audio.streamdata as streamdata
from paddlespeech.audio.text.text_featurizer import TextFeaturizer
from yacs.config import CfgNode
__all__ = ["BatchDataLoader", "StreamDataLoader"]
logger = Log(__name__).getlog()
@@ -61,6 +60,7 @@ def batch_collate(x):
"""
return x[0]
def read_preprocess_cfg(preprocess_conf_file):
augment_conf = dict()
preprocess_cfg = CfgNode(new_allowed=True)
@@ -84,6 +84,7 @@ def read_preprocess_cfg(preprocess_conf_file):
augment_conf['t_replace_with_zero'] = process['replace_with_zero']
return augment_conf
class StreamDataLoader():
def __init__(self,
manifest_file: str,
@@ -131,10 +132,14 @@ class StreamDataLoader():
world_size = paddle.distributed.get_world_size()
except Exception as e:
logger.warning(e)
logger.warning("cannot get world_size using paddle.distributed.get_world_size(), use world_size=1")
assert len(shardlist) >= world_size, "the length of the shard list should be >= the number of gpus/xpus/..."
logger.warning(
"cannot get world_size using paddle.distributed.get_world_size(), use world_size=1"
)
assert len(shardlist) >= world_size, \
"the length of the shard list should be >= the number of gpus/xpus/..."
update_n_iter_processes = int(max(min(len(shardlist)/world_size - 1, self.n_iter_processes), 0))
update_n_iter_processes = int(
max(min(len(shardlist) / world_size - 1, self.n_iter_processes), 0))
logger.info(f"update_n_iter_processes {update_n_iter_processes}")
if update_n_iter_processes != self.n_iter_processes:
self.n_iter_processes = update_n_iter_processes
@@ -142,44 +147,50 @@ class StreamDataLoader():
if self.dist_sampler:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_node if train_mode else streamdata.placeholder(),
streamdata.SimpleShardList(shardlist), streamdata.split_by_node
if train_mode else streamdata.placeholder(),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
streamdata.tarfile_to_samples(streamdata.reraise_exception))
else:
base_dataset = streamdata.DataPipeline(
streamdata.SimpleShardList(shardlist),
streamdata.split_by_worker,
streamdata.tarfile_to_samples(streamdata.reraise_exception)
)
streamdata.tarfile_to_samples(streamdata.reraise_exception))
self.dataset = base_dataset.append_list(
streamdata.audio_tokenize(symbol_table),
streamdata.audio_data_filter(frame_shift=frame_shift, max_length=maxlen_in, min_length=minlen_in, token_max_length=maxlen_out, token_min_length=minlen_out),
streamdata.audio_data_filter(
frame_shift=frame_shift,
max_length=maxlen_in,
min_length=minlen_in,
token_max_length=maxlen_out,
token_min_length=minlen_out),
streamdata.audio_resample(resample_rate=resample_rate),
streamdata.audio_compute_fbank(num_mel_bins=num_mel_bins, frame_length=frame_length, frame_shift=frame_shift, dither=dither),
streamdata.audio_spec_aug(**augment_conf) if train_mode else streamdata.placeholder(), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.audio_compute_fbank(
num_mel_bins=num_mel_bins,
frame_length=frame_length,
frame_shift=frame_shift,
dither=dither),
streamdata.audio_spec_aug(**augment_conf)
if train_mode else streamdata.placeholder(
), # num_t_mask=2, num_f_mask=2, max_t=40, max_f=30, max_w=80)
streamdata.shuffle(shuffle_size),
streamdata.sort(sort_size=sort_size),
streamdata.batched(batch_size),
streamdata.audio_padding(),
streamdata.audio_cmvn(cmvn_file)
)
streamdata.audio_cmvn(cmvn_file))
if paddle.__version__ >= '2.3.2':
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
prefetch_factor = self.prefetch_factor,
batch_size=None
)
prefetch_factor=self.prefetch_factor,
batch_size=None)
else:
self.loader = streamdata.WebLoader(
self.dataset,
num_workers=self.n_iter_processes,
batch_size=None
)
batch_size=None)
def __iter__(self):
return self.loader.__iter__()
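The WebLoader construction above forwards prefetch_factor only on Paddle builds that accept it, gated by a string comparison on paddle.__version__. Note that string comparison misorders releases such as '2.10.0', which sorts before '2.3.2'; a hedged sketch of a numeric gate, assuming release-style version strings like '2.4.1':

import paddle

def webloader_kwargs(num_workers: int, prefetch_factor: int) -> dict:
    kwargs = {'num_workers': num_workers, 'batch_size': None}
    # Compare version components numerically rather than lexicographically.
    version = tuple(int(part) for part in paddle.__version__.split('.')[:3])
    if version >= (2, 3, 2):
        kwargs['prefetch_factor'] = prefetch_factor
    return kwargs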
@@ -188,7 +199,9 @@ class StreamDataLoader():
return self.__iter__()
def __len__(self):
logger.info("Stream dataloader does not support calculate the length of the dataset")
logger.info(
"Stream dataloader does not support calculate the length of the dataset"
)
return -1
@@ -358,30 +371,31 @@ class DataLoaderFactory():
config['maxlen_out'] = float('inf')
config['dist_sampler'] = False
else:
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
return StreamDataLoader(
manifest_file=config.manifest,
train_mode=config.train_mode,
unit_type=config.unit_type,
preprocess_conf=config.preprocess_config,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.dist_sampler,
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath,
raise KeyError(
"invalid mode type, please input one of 'train', 'valid', 'test', 'align'"
)
return StreamDataLoader(
manifest_file=config.manifest,
train_mode=config.train_mode,
unit_type=config.unit_type,
preprocess_conf=config.preprocess_config,
batch_size=config.batch_size,
num_mel_bins=config.feat_dim,
frame_length=config.window_ms,
frame_shift=config.stride_ms,
dither=config.dither,
minlen_in=config.minlen_in,
maxlen_in=config.maxlen_in,
minlen_out=config.minlen_out,
maxlen_out=config.maxlen_out,
resample_rate=config.resample_rate,
shuffle_size=config.shuffle_size,
sort_size=config.sort_size,
n_iter_processes=config.num_workers,
prefetch_factor=config.prefetch_factor,
dist_sampler=config.dist_sampler,
cmvn_file=config.cmvn_file,
vocab_filepath=config.vocab_filepath, )
else:
if mode == 'train':
config['manifest'] = config.train_manifest
@@ -411,7 +425,7 @@ class DataLoaderFactory():
config['train_mode'] = False
config['sortagrad'] = False
config['batch_size'] = config.get('decode', dict()).get(
'decode_batch_size', 1)
'decode_batch_size', 1)
config['maxlen_in'] = float('inf')
config['maxlen_out'] = float('inf')
config['minibatches'] = 0
@@ -427,7 +441,9 @@ class DataLoaderFactory():
config['dist_sampler'] = False
config['shortest_first'] = False
else:
raise KeyError("not valid mode type!!, please input one of 'train, valid, test, align'")
raise KeyError(
"not valid mode type!!, please input one of 'train, valid, test, align'"
)
return BatchDataLoader(
json_file=config.manifest,
@@ -450,4 +466,3 @@ class DataLoaderFactory():
num_encs=config.num_encs,
dist_sampler=config.dist_sampler,
shortest_first=config.shortest_first)
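Since the factory reads every field from the config node, a call site only picks the mode. A hypothetical usage sketch, where config and args stand in for the trainer's yacs CfgNode and argparse namespace:

train_loader = DataLoaderFactory.get_dataloader('train', config, args)
for batch in train_loader:
    # Batches arrive tokenized, fbank-extracted, sorted, padded and
    # CMVN-normalized by the pipelines set up above.
    pass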

@@ -26,6 +26,8 @@ import paddle
from paddle import jit
from paddle import nn
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.frontend.utility import IGNORE_ID
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.modules.cmvn import GlobalCMVN
@@ -38,8 +40,6 @@ from paddlespeech.s2t.modules.mask import subsequent_mask
from paddlespeech.s2t.utils import checkpoint
from paddlespeech.s2t.utils import layer_tools
from paddlespeech.s2t.utils.log import Log
from paddlespeech.audio.utils.tensor_utils import add_sos_eos
from paddlespeech.audio.utils.tensor_utils import th_accuracy
from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ["U2STModel", "U2STInferModel"]
@@ -401,8 +401,8 @@ class U2STBaseModel(nn.Layer):
xs: paddle.Tensor,
offset: int,
required_cache_size: int,
att_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor = paddle.zeros([0, 0, 0, 0]),
att_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
cnn_cache: paddle.Tensor=paddle.zeros([0, 0, 0, 0]),
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
""" Export interface for c++ call, give input chunk xs, and return
output from time 0 to current chunk.
@@ -435,8 +435,8 @@ class U2STBaseModel(nn.Layer):
paddle.Tensor: new conformer cnn cache required for next chunk, with
same shape as the original cnn_cache.
"""
return self.encoder.forward_chunk(
xs, offset, required_cache_size, att_cache, cnn_cache)
return self.encoder.forward_chunk(xs, offset, required_cache_size,
att_cache, cnn_cache)
# @jit.to_static
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:

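The forward_chunk signature above is the export surface for streaming inference: the caller feeds fixed-size feature chunks and threads the attention and convolution caches through successive calls. A minimal sketch of that loop, where model, feature_chunks, and the required_cache_size value of -1 (assumed to mean "keep the full history") are placeholders:

import paddle

att_cache = paddle.zeros([0, 0, 0, 0])  # empty caches for the first chunk
cnn_cache = paddle.zeros([0, 0, 0, 0])
offset = 0
for chunk in feature_chunks:  # e.g. [1, T_chunk, feat_dim] fbank tensors
    ys, att_cache, cnn_cache = model.forward_chunk(
        chunk, offset, -1, att_cache=att_cache, cnn_cache=cnn_cache)
    offset += ys.shape[1]  # advance by the encoder frames emitted for this chunk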