@@ -25,6 +25,8 @@ import paddle
 from paddle import distributed as dist
 from yacs.config import CfgNode
 
+from deepspeech.frontend.featurizer import TextFeaturizer
+from deepspeech.frontend.utility import load_dict
+from deepspeech.io.dataloader import BatchDataLoader
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
@@ -80,8 +82,8 @@ class U2Trainer(Trainer):
     def train_batch(self, batch_index, batch_data, msg):
         train_conf = self.config.training
         start = time.time()
         utt, audio, audio_len, text, text_len = batch_data
 
         loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
                                                     text_len)
         # loss div by `batch_size * accum_grad`
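
# --- Illustrative sketch (not part of the patch) ---------------------------
# The comment above notes the loss is divided by `batch_size * accum_grad`:
# if the model returns a loss summed over utterances, dividing by the batch
# size gives a per-utterance mean, and dividing again by `accum_grad` keeps
# the update comparable to one large batch when several micro-batches are
# accumulated before an optimizer step. The helper below is hypothetical,
# not the deepspeech implementation.
def scale_loss(summed_loss, batch_size, accum_grad):
    """Loss used for backward() under gradient accumulation (assumed convention)."""
    return summed_loss / (batch_size * accum_grad)

# e.g. a summed loss of 32.0 over 16 utterances with accum_grad=2
# backpropagates as 1.0 per micro-batch.
assert scale_loss(32.0, 16, 2) == 1.0
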
@@ -124,6 +126,7 @@ class U2Trainer(Trainer):
         valid_losses = defaultdict(list)
         num_seen_utts = 1
         total_loss = 0.0
 
         for i, batch in enumerate(self.valid_loader):
             utt, audio, audio_len, text, text_len = batch
             loss, attention_loss, ctc_loss = self.model(audio, audio_len, text,
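
# --- Illustrative sketch (not part of the patch) ---------------------------
# How `num_seen_utts` / `total_loss` typically combine into an average
# validation loss over the whole set: each batch's loss is weighted by the
# number of utterances it contains. Starting the counter at 1 (as above)
# avoids a division by zero if the loader is empty. Values are made up.
total_loss, num_seen_utts = 0.0, 1
for batch_mean_loss, batch_utts in [(12.0, 8), (9.0, 8), (5.5, 4)]:
    total_loss += batch_mean_loss * batch_utts   # weight by utterance count
    num_seen_utts += batch_utts
print("avg valid loss per utt:", total_loss / num_seen_utts)
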
@@ -305,10 +308,8 @@ class U2Trainer(Trainer):
         model_conf.output_dim = self.train_loader.vocab_size
         model_conf.freeze()
         model = U2Model.from_config(model_conf)
 
         if self.parallel:
             model = paddle.DataParallel(model)
 
         logger.info(f"{model}")
         layer_tools.print_params(model, logger.info)
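
# --- Illustrative sketch (not part of the patch) ---------------------------
# The `if self.parallel:` branch above wraps the model in paddle.DataParallel
# for multi-process data-parallel training. A minimal stand-alone version,
# with a hypothetical stand-in network, might look like this:
import paddle
from paddle import distributed as dist

class TinyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
if dist.get_world_size() > 1:           # true when launched on multiple ranks
    dist.init_parallel_env()            # set up communication first
    model = paddle.DataParallel(model)  # gradients are synchronized across ranks
print(model)                            # mirrors logger.info(f"{model}") above
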
@@ -379,13 +380,13 @@ class U2Tester(U2Trainer):
     def __init__(self, config, args):
         super().__init__(config, args)
 
-    def ordid2token(self, texts, texts_len):
+    def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
         trans = []
         for text, n in zip(texts, texts_len):
             n = n.numpy().item()
             ids = text[:n]
-            trans.append(''.join([chr(i) for i in ids]))
+            trans.append(text_feature.defeaturize(ids.numpy().tolist()))
         return trans
 
     def compute_metrics(self,
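
# --- Illustrative sketch (not part of the patch) ---------------------------
# What the new id2token() delegates to the featurizer: trimming each padded
# id sequence to its true length and mapping ids back to text. The toy
# vocabulary and helper below are assumptions for a char unit type; the real
# TextFeaturizer.defeaturize also covers word/spm units and special tokens.
vocab = ['<blank>', 'h', 'e', 'l', 'o', ' ', 'w', 'r', 'd']

def ids_to_text(ids, vocab):
    """Map integer token ids back to a string (char units assumed)."""
    return ''.join(vocab[i] for i in ids)

padded = [1, 2, 3, 3, 4, 0, 0, 0]       # hypothesis padded to length 8
n = 5                                   # true length from texts_len
print(ids_to_text(padded[:n], vocab))   # -> "hello"
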
@@ -401,8 +402,11 @@ class U2Tester(U2Trainer):
         error_rate_func = error_rate.cer if cfg.error_rate_type == 'cer' else error_rate.wer
 
         start_time = time.time()
-        text_feature = self.test_loader.collate_fn.text_feature
-        target_transcripts = self.ordid2token(texts, texts_len)
+        text_feature = TextFeaturizer(
+            unit_type=self.config.collator.unit_type,
+            vocab_filepath=self.config.collator.vocab_filepath,
+            spm_model_prefix=self.config.collator.spm_model_prefix)
+        target_transcripts = self.id2token(texts, texts_len, text_feature)
         result_transcripts = self.model.decode(
             audio,
             audio_len,
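
# --- Illustrative sketch (not part of the patch) ---------------------------
# What the selected error_rate.cer / error_rate.wer metric computes
# conceptually: edit distance between reference and hypothesis divided by the
# reference length, over characters or words. This toy implementation is an
# assumption, not the deepspeech utility itself.
def edit_distance(ref, hyp):
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # deletion
                                     dp[j - 1] + 1,    # insertion
                                     prev + (r != h))  # substitution / match
    return dp[-1]

def cer(reference, hypothesis):
    return edit_distance(list(reference), list(hypothesis)) / len(reference)

print(cer("hello world", "helo wrld"))   # 2 edits / 11 chars ≈ 0.18
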
@@ -450,7 +454,7 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Test Total Examples: {len(self.test_loader.dataset)}")
 
-        stride_ms = self.test_loader.collate_fn.stride_ms
+        stride_ms = self.config.collator.stride_ms
         error_rate_type = None
         errors_sum, len_refs, num_ins = 0.0, 0, 0
         num_frames = 0.0
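
# --- Illustrative sketch (not part of the patch) ---------------------------
# Why the tester tracks `stride_ms` and `num_frames`: the number of feature
# frames times the frame shift (in ms) gives the audio duration, and decode
# wall time divided by that duration gives the real-time factor usually
# logged at the end of a test run. The numbers below are made up.
stride_ms = 10.0          # frame shift from the collator config
num_frames = 98_400       # feature frames accumulated over the test set
decode_seconds = 250.0    # wall-clock decoding time

audio_seconds = num_frames * stride_ms / 1000.0
rtf = decode_seconds / audio_seconds
print(f"{audio_seconds:.1f} s of audio, RTF = {rtf:.3f}")
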
@@ -525,8 +529,9 @@ class U2Tester(U2Trainer):
         self.model.eval()
         logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
 
-        stride_ms = self.align_loader.collate_fn.stride_ms
-        token_dict = self.align_loader.collate_fn.vocab_list
+        stride_ms = self.config.collator.stride_ms
+        token_dict = self.args.char_list
 
         with open(self.args.result_file, 'w') as fout:
             # one example in batch
             for i, batch in enumerate(self.align_loader):
@@ -613,6 +618,11 @@ class U2Tester(U2Trainer):
         except KeyboardInterrupt:
             sys.exit(-1)
 
+    def setup_dict(self):
+        # load dictionary for debug log
+        self.args.char_list = load_dict(self.args.dict_path,
+                                        "maskctc" in self.args.model_name)
+
     def setup(self):
         """Setup the experiment.
         """
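
# --- Illustrative sketch (not part of the patch) ---------------------------
# What the new setup_dict() relies on: load_dict(dict_path, maskctc) returns
# the token list used for debug logging. The file format below (one
# "token id" pair per line) and the <mask> handling for maskctc models are
# assumptions; the real deepspeech helper may differ.
def load_char_list(dict_path, maskctc=False):
    char_list = []
    with open(dict_path, encoding='utf8') as f:
        for line in f:
            if not line.strip():
                continue
            token = line.split()[0]        # keep the token, drop the id column
            char_list.append(token)
    if maskctc and '<mask>' not in char_list:
        char_list.append('<mask>')         # mask token used by non-autoregressive maskctc
    return char_list
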
@@ -624,6 +634,8 @@ class U2Tester(U2Trainer):
         self.setup_dataloader()
         self.setup_model()
+        self.setup_dict()
+
         self.iteration = 0
         self.epoch = 0