|
|
|
@ -18,13 +18,11 @@ import time
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
import jsonlines
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle
|
|
|
|
|
from paddle import distributed as dist
|
|
|
|
|
from yacs.config import CfgNode
|
|
|
|
|
|
|
|
|
|
from paddlespeech.s2t.frontend.featurizer import TextFeaturizer
|
|
|
|
|
from paddlespeech.s2t.io.dataloader import BatchDataLoader
|
|
|
|
@ -208,8 +206,7 @@ class U2STTrainer(Trainer):
|
|
|
|
|
k.split(',')) == 2 else ""
|
|
|
|
|
msg += ","
|
|
|
|
|
msg = msg[:-1] # remove the last ","
|
|
|
|
|
if (batch_index + 1
|
|
|
|
|
) % self.config.log_interval == 0:
|
|
|
|
|
if (batch_index + 1) % self.config.log_interval == 0:
|
|
|
|
|
logger.info(msg)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(e)
|
|
|
|
@ -260,7 +257,8 @@ class U2STTrainer(Trainer):
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
load_aux_output=load_transcript,
|
|
|
|
@ -281,7 +279,8 @@ class U2STTrainer(Trainer):
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
load_aux_output=load_transcript,
|
|
|
|
@ -290,7 +289,8 @@ class U2STTrainer(Trainer):
|
|
|
|
|
logger.info("Setup train/valid Dataloader!")
|
|
|
|
|
else:
|
|
|
|
|
# test dataset, return raw text
|
|
|
|
|
decode_batch_size = config.get('decode',dict()).get('decode_batch_size', 1)
|
|
|
|
|
decode_batch_size = config.get('decode', dict()).get(
|
|
|
|
|
'decode_batch_size', 1)
|
|
|
|
|
self.test_loader = BatchDataLoader(
|
|
|
|
|
json_file=config.test_manifest,
|
|
|
|
|
train_mode=False,
|
|
|
|
@ -305,7 +305,8 @@ class U2STTrainer(Trainer):
|
|
|
|
|
batch_frames_in=0,
|
|
|
|
|
batch_frames_out=0,
|
|
|
|
|
batch_frames_inout=0,
|
|
|
|
|
preprocess_conf=config.preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
preprocess_conf=config.
|
|
|
|
|
preprocess_config, # aug will be off when train_mode=False
|
|
|
|
|
n_iter_processes=config.num_workers,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
num_encs=1,
|
|
|
|
|