|
|
@ -208,8 +208,7 @@ class U2STTrainer(Trainer):
|
|
|
|
k.split(',')) == 2 else ""
|
|
|
|
k.split(',')) == 2 else ""
|
|
|
|
msg += ","
|
|
|
|
msg += ","
|
|
|
|
msg = msg[:-1] # remove the last ","
|
|
|
|
msg = msg[:-1] # remove the last ","
|
|
|
|
if (batch_index + 1
|
|
|
|
if (batch_index + 1) % self.config.log_interval == 0:
|
|
|
|
) % self.config.log_interval == 0:
|
|
|
|
|
|
|
|
logger.info(msg)
|
|
|
|
logger.info(msg)
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logger.error(e)
|
|
|
|
logger.error(e)
|
|
|
@ -260,7 +259,8 @@ class U2STTrainer(Trainer):
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
subsampling_factor=1,
|
|
|
|
subsampling_factor=1,
|
|
|
|
load_aux_output=load_transcript,
|
|
|
|
load_aux_output=load_transcript,
|
|
|
@ -281,7 +281,8 @@ class U2STTrainer(Trainer):
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
subsampling_factor=1,
|
|
|
|
subsampling_factor=1,
|
|
|
|
load_aux_output=load_transcript,
|
|
|
|
load_aux_output=load_transcript,
|
|
|
@ -290,7 +291,8 @@ class U2STTrainer(Trainer):
|
|
|
|
logger.info("Setup train/valid Dataloader!")
|
|
|
|
logger.info("Setup train/valid Dataloader!")
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# test dataset, return raw text
|
|
|
|
# test dataset, return raw text
|
|
|
|
decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1)
|
|
|
|
decode_batch_size = config.get('decode', dict()).get(
|
|
|
|
|
|
|
|
'decode_batch_size', 1)
|
|
|
|
self.test_loader = BatchDataLoader(
|
|
|
|
self.test_loader = BatchDataLoader(
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
train_mode=False,
|
|
|
|
train_mode=False,
|
|
|
@ -305,7 +307,8 @@ class U2STTrainer(Trainer):
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_in=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_out=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
batch_frames_inout=0,
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
subsampling_factor=1,
|
|
|
|
subsampling_factor=1,
|
|
|
|
num_encs=1,
|
|
|
|
num_encs=1,
|
|
|
|