@@ -184,229 +184,79 @@ class DeepSpeech2Trainer(Trainer):
        self.logger.info("Setup model/optimizer/criterion!")

    def compute_losses(self, inputs, outputs):
        pass

    def test(self, exe, dev_batch_reader, test_program, test_reader,
             fetch_list):
        '''Test the model.

        :param exe: The executor of the program.
        :type exe: Executor
        :param dev_batch_reader: Validation data reader.
        :type dev_batch_reader: callable
        :param test_program: The program to test.
        :type test_program: Program
        :param test_reader: Reader of the test data.
        :type test_reader: Reader
        :param fetch_list: Variables to fetch (here, the test loss).
        :type fetch_list: list
        :return: The mean test loss over the test set.
        :rtype: float
        '''
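        # The loop below drains the reader (a py_reader-style reader, judging
        # by its start()/reset() API) until EOF; with return_numpy=False,
        # exe.run returns LoDTensors, hence the np.array conversion.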
        test_reader.start()
        epoch_loss = []
        while True:
            try:
                each_loss = exe.run(
                    program=test_program,
                    fetch_list=fetch_list,
                    return_numpy=False)
                epoch_loss.extend(np.array(each_loss[0]))
            except fluid.core.EOFException:
                test_reader.reset()
                break
        return np.mean(np.array(epoch_loss))

    def train(self,
              train_batch_reader,
              dev_batch_reader,
              feeding_dict,
              learning_rate,
              gradient_clipping,
              num_epoch,
              batch_size,
              num_samples,
              save_epoch=100,
              num_iterations_print=100,
              test_off=False):
        """Train the model.

        :param train_batch_reader: Train data reader.
        :type train_batch_reader: callable
        :param dev_batch_reader: Validation data reader.
        :type dev_batch_reader: callable
        :param feeding_dict: Feeding is a map of field name and tuple index
                             of the data that reader returns.
        :type feeding_dict: dict|list
        :param learning_rate: Learning rate for the Adam optimizer.
        :type learning_rate: float
        :param gradient_clipping: Gradient clipping threshold.
        :type gradient_clipping: float
        :param num_epoch: Number of training epochs.
        :type num_epoch: int
        :param batch_size: Batch size.
        :type batch_size: int
        :param num_samples: Number of training samples.
        :type num_samples: int
        :param save_epoch: Number of epochs between checkpoint saves.
        :type save_epoch: int
        :param num_iterations_print: Number of training iterations between
                                     prints of the training loss.
        :type num_iterations_print: int
        :param test_off: Turn off testing.
        :type test_off: bool
        """
        if isinstance(self._place, fluid.CUDAPlace):
            dev_count = fluid.core.get_cuda_device_count()
        else:
            dev_count = int(os.environ.get('CPU_NUM', 1))
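        # Device count scales the step math below: all visible GPUs when
        # running on CUDAPlace, otherwise the CPU_NUM environment variable
        # (defaulting to 1).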

        # prepare the network
        train_program = fluid.Program()
        startup_prog = fluid.Program()
        with fluid.program_guard(train_program, startup_prog):
            with fluid.unique_name.guard():
                train_reader, log_probs, ctc_loss = self.create_network()
                # prepare optimizer
                optimizer = fluid.optimizer.AdamOptimizer(
                    learning_rate=fluid.layers.exponential_decay(
                        learning_rate=learning_rate,
                        decay_steps=num_samples / batch_size / dev_count,
                        decay_rate=0.83,
                        staircase=True),
                    grad_clip=fluid.clip.GradientClipByGlobalNorm(
                        clip_norm=gradient_clipping))
                optimizer.minimize(loss=ctc_loss)
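                # With staircase=True the learning rate drops by a factor of
                # 0.83 once per epoch-equivalent of steps. A worked example
                # with hypothetical numbers: num_samples=128000, batch_size=32,
                # dev_count=4 gives decay_steps = 128000 / 32 / 4 = 1000, so
                # after 1000 global steps the rate is learning_rate * 0.83.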

        test_prog = fluid.Program()
        with fluid.program_guard(test_prog, startup_prog):
            with fluid.unique_name.guard():
                test_reader, _, ctc_loss = self.create_network()

        test_prog = test_prog.clone(for_test=True)
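        # clone(for_test=True) prunes the backward and optimizer ops from the
        # program, so the evaluation graph runs forward-only.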

        exe = fluid.Executor(self._place)
        exe.run(startup_prog)

        # init from a pretrained model, to better solve the current task
        pre_epoch = 0
        if self._init_from_pretrained_model:
            pre_epoch = self.init_from_pretrained_model(exe, train_program)

        train_reader.set_batch_generator(train_batch_reader)
        test_reader.set_batch_generator(dev_batch_reader)
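        # Both readers are py_reader-style: set_batch_generator attaches the
        # Python batch generator, and the start()/reset() calls below drive it
        # once per epoch.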

        # run train
        for epoch_id in range(num_epoch):
            train_reader.start()
            epoch_loss = []
            time_begin = time.time()
            batch_id = 0
            step = 0
            while True:
                try:
                    fetch_list = [ctc_loss.name]
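                    # Note: compiled_prog (used below) is not defined in this
                    # hunk; it is presumably a fluid.CompiledProgram built
                    # from train_program in elided context.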

                    if batch_id % num_iterations_print == 0:
                        fetch = exe.run(
                            program=compiled_prog,
                            fetch_list=fetch_list,
                            return_numpy=False)
                        each_loss = fetch[0]
                        epoch_loss.extend(np.array(each_loss[0]) / batch_size)
                        print("epoch: %d, batch: %d, train loss: %f\n" %
                              (epoch_id, batch_id,
                               np.mean(each_loss[0]) / batch_size))
                    else:
                        each_loss = exe.run(
                            program=compiled_prog,
                            fetch_list=[],
                            return_numpy=False)

                    batch_id = batch_id + 1
                except fluid.core.EOFException:
                    train_reader.reset()
                    break

            time_end = time.time()
            used_time = time_end - time_begin
            if test_off:
                print("\n--------Time: %f sec, epoch: %d, train loss: %f\n" %
                      (used_time, epoch_id, np.mean(np.array(epoch_loss))))
            else:
                print('\n----------Begin test...')
                test_loss = self.test(
                    exe,
                    dev_batch_reader=dev_batch_reader,
                    test_program=test_prog,
                    test_reader=test_reader,
                    fetch_list=[ctc_loss])
                print(
                    "--------Time: %f sec, epoch: %d, train loss: %f, test loss: %f"
                    % (used_time, epoch_id + pre_epoch,
                       np.mean(np.array(epoch_loss)), test_loss / batch_size))

            if (epoch_id + 1) % save_epoch == 0:
                self.save_param(exe, train_program,
                                "epoch_" + str(epoch_id + pre_epoch))

        self.save_param(exe, train_program, "step_final")

        print("\n------------Training finished!!!-------------")

    def compute_losses(self, inputs, outputs):
        _, texts, logits_len, texts_len = inputs
        logits = outputs
        loss = self.criterion(logits, texts, logits_len, texts_len)
        return loss
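    # self.criterion is assumed to be a CTC-style loss (e.g. an instance of
    # paddle.nn.CTCLoss), which needs the unpadded logit and text lengths in
    # addition to the logits and the label texts.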

    def train_batch(self):
        start = time.time()
        batch = self.read_batch()
        data_loader_time = time.time() - start

        self.optimizer.clear_grad()
        self.model.train()
        audio, text, audio_len, text_len = batch
        outputs = self.model(audio, text, audio_len, text_len)
        loss = self.compute_losses(batch, outputs)
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start
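        # Standard Paddle dygraph update: clear_grad() wipes gradients from
        # the previous step, loss.backward() accumulates new ones, and
        # optimizer.step() applies them.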

        losses_np = {'loss': float(loss)}
        msg = "Rank: {}, ".format(dist.get_rank())
        msg += "epoch: {}, ".format(self.epoch)
        msg += "step: {}, ".format(self.iteration)
        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
                                                  iteration_time)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in losses_np.items())
        self.logger.info(msg)

        if dist.get_rank() == 0:
            for k, v in losses_np.items():
                self.visualizer.add_scalar("train/{}".format(k), v,
                                           self.iteration)
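    # self.visualizer is assumed to be a VisualDL LogWriter-like object whose
    # add_scalar(tag, value, step) records scalar curves; only rank 0 writes,
    # which avoids duplicate logs under data-parallel training.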

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        valid_losses = defaultdict(list)
        for i, batch in enumerate(self.valid_loader):
            audio, text, audio_len, text_len = batch
            outputs = self.model(audio, text, audio_len, text_len)
            losses = self.compute_losses(batch, outputs)
            valid_losses['val_loss'].append(float(losses))

        # write visual log
        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

        # logging
        msg = "Valid: "
        msg += "step: {}, ".format(self.iteration)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in valid_losses.items())
        self.logger.info(msg)

        for k, v in valid_losses.items():
            self.visualizer.add_scalar("valid/{}".format(k), v, self.iteration)

    def infer_batch_probs(self, infer_data):
        """Infer the prob matrices for a batch of speech utterances.

        :param infer_data: List of utterances to infer, with each utterance
                           consisting of a tuple of audio features and
                           transcription text (empty string).
        :type infer_data: list
        :return: List of 2-D probability matrices, each consisting of prob
                 vectors for one speech utterance.
        :rtype: list of matrix
        """
        # define inferer
        infer_program = fluid.Program()
        startup_prog = fluid.Program()

        # prepare the network
        with fluid.program_guard(infer_program, startup_prog):
            with fluid.unique_name.guard():
                feeder, log_probs, _ = self.create_network(is_infer=True)

        infer_program = infer_program.clone(for_test=True)
        exe = fluid.Executor(self._place)
        exe.run(startup_prog)

        # init params from the pretrained model
        if not self._init_from_pretrained_model:
            exit("No pretrained model file path!")
        self.init_from_pretrained_model(exe, infer_program)

        infer_results = []
        time_begin = time.time()

        # run inference
        for i in range(infer_data[0].shape[0]):
            each_log_probs = exe.run(
                program=infer_program,
                feed=feeder.feed(
                    [[infer_data[0][i], infer_data[2][i], infer_data[3][i]]]),
                fetch_list=[log_probs],
                return_numpy=False)
            infer_results.extend(np.array(each_log_probs[0]))

        # slice result
        infer_results = np.array(infer_results)
        seq_len = (infer_data[2] - 1) // 3 + 1
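        # The "// 3" appears to mirror the temporal subsampling of the
        # convolutional front end, so each utterance's number of output frames
        # is recovered from its input length (an assumption about the conv
        # stride; adjust if the network differs).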

        start_pos = [0] * (infer_data[0].shape[0] + 1)
        for i in range(infer_data[0].shape[0]):
            start_pos[i + 1] = start_pos[i] + seq_len[i][0]
        probs_split = [
            infer_results[start_pos[i]:start_pos[i + 1]]
            for i in range(0, infer_data[0].shape[0])
        ]

        return probs_split

        self.model.eval()
        audio, text, audio_len, text_len = infer_data
        logits, probs = self.model.predict(audio, audio_len)
        return probs
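        # Dygraph path: model.predict is assumed to process the whole padded
        # batch in one call and return (logits, probs), which removes the
        # per-utterance feed loop and manual slicing used by the static-graph
        # body above.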

    def decode_batch_greedy(self, probs_split, vocab_list):
        """Decode by best path for a batch of probs matrix input.