|
|
|
@ -1,15 +1,22 @@
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import copy
|
|
|
|
|
import json
|
|
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
import tempfile
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
from . import extension
|
|
|
|
|
from ..updaters.trainer import Trainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PlotAttentionReport(extension.Extension):
|
|
|
|
@ -37,20 +44,19 @@ class PlotAttentionReport(extension.Extension):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
att_vis_fn,
|
|
|
|
|
data,
|
|
|
|
|
outdir,
|
|
|
|
|
converter,
|
|
|
|
|
transform,
|
|
|
|
|
device,
|
|
|
|
|
reverse=False,
|
|
|
|
|
ikey="input",
|
|
|
|
|
iaxis=0,
|
|
|
|
|
okey="output",
|
|
|
|
|
oaxis=0,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
):
|
|
|
|
|
self,
|
|
|
|
|
att_vis_fn,
|
|
|
|
|
data,
|
|
|
|
|
outdir,
|
|
|
|
|
converter,
|
|
|
|
|
transform,
|
|
|
|
|
device,
|
|
|
|
|
reverse=False,
|
|
|
|
|
ikey="input",
|
|
|
|
|
iaxis=0,
|
|
|
|
|
okey="output",
|
|
|
|
|
oaxis=0,
|
|
|
|
|
subsampling_factor=1, ):
|
|
|
|
|
self.att_vis_fn = att_vis_fn
|
|
|
|
|
self.data = copy.deepcopy(data)
|
|
|
|
|
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
|
|
|
@ -77,44 +83,30 @@ class PlotAttentionReport(extension.Extension):
|
|
|
|
|
for i in range(num_encs):
|
|
|
|
|
for idx, att_w in enumerate(att_ws[i]):
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
i + 1,
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], i + 1, )
|
|
|
|
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
|
|
|
|
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
i + 1,
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], i + 1, )
|
|
|
|
|
np.save(np_filename.format(trainer), att_w)
|
|
|
|
|
self._plot_and_save_attention(att_w, filename.format(trainer))
|
|
|
|
|
self._plot_and_save_attention(att_w,
|
|
|
|
|
filename.format(trainer))
|
|
|
|
|
# han
|
|
|
|
|
for idx, att_w in enumerate(att_ws[num_encs]):
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], )
|
|
|
|
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
|
|
|
|
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], )
|
|
|
|
|
np.save(np_filename.format(trainer), att_w)
|
|
|
|
|
self._plot_and_save_attention(
|
|
|
|
|
att_w, filename.format(trainer), han_mode=True
|
|
|
|
|
)
|
|
|
|
|
att_w, filename.format(trainer), han_mode=True)
|
|
|
|
|
else:
|
|
|
|
|
for idx, att_w in enumerate(att_ws):
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
|
|
|
|
uttid_list[idx], )
|
|
|
|
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
|
|
|
|
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], )
|
|
|
|
|
np.save(np_filename.format(trainer), att_w)
|
|
|
|
|
self._plot_and_save_attention(att_w, filename.format(trainer))
|
|
|
|
|
|
|
|
|
@ -131,8 +123,7 @@ class PlotAttentionReport(extension.Extension):
|
|
|
|
|
logger.add_figure(
|
|
|
|
|
"%s_att%d" % (uttid_list[idx], i + 1),
|
|
|
|
|
plot.gcf(),
|
|
|
|
|
step,
|
|
|
|
|
)
|
|
|
|
|
step, )
|
|
|
|
|
# han
|
|
|
|
|
for idx, att_w in enumerate(att_ws[num_encs]):
|
|
|
|
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
|
|
|
@ -140,8 +131,7 @@ class PlotAttentionReport(extension.Extension):
|
|
|
|
|
logger.add_figure(
|
|
|
|
|
"%s_han" % (uttid_list[idx]),
|
|
|
|
|
plot.gcf(),
|
|
|
|
|
step,
|
|
|
|
|
)
|
|
|
|
|
step, )
|
|
|
|
|
else:
|
|
|
|
|
for idx, att_w in enumerate(att_ws):
|
|
|
|
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
|
|
|
@ -286,20 +276,19 @@ class PlotCTCReport(extension.Extension):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
ctc_vis_fn,
|
|
|
|
|
data,
|
|
|
|
|
outdir,
|
|
|
|
|
converter,
|
|
|
|
|
transform,
|
|
|
|
|
device,
|
|
|
|
|
reverse=False,
|
|
|
|
|
ikey="input",
|
|
|
|
|
iaxis=0,
|
|
|
|
|
okey="output",
|
|
|
|
|
oaxis=0,
|
|
|
|
|
subsampling_factor=1,
|
|
|
|
|
):
|
|
|
|
|
self,
|
|
|
|
|
ctc_vis_fn,
|
|
|
|
|
data,
|
|
|
|
|
outdir,
|
|
|
|
|
converter,
|
|
|
|
|
transform,
|
|
|
|
|
device,
|
|
|
|
|
reverse=False,
|
|
|
|
|
ikey="input",
|
|
|
|
|
iaxis=0,
|
|
|
|
|
okey="output",
|
|
|
|
|
oaxis=0,
|
|
|
|
|
subsampling_factor=1, ):
|
|
|
|
|
self.ctc_vis_fn = ctc_vis_fn
|
|
|
|
|
self.data = copy.deepcopy(data)
|
|
|
|
|
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
|
|
|
@ -325,29 +314,19 @@ class PlotCTCReport(extension.Extension):
|
|
|
|
|
for i in range(num_encs):
|
|
|
|
|
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
i + 1,
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], i + 1, )
|
|
|
|
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
|
|
|
|
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
i + 1,
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], i + 1, )
|
|
|
|
|
np.save(np_filename.format(trainer), ctc_prob)
|
|
|
|
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
|
|
|
|
else:
|
|
|
|
|
for idx, ctc_prob in enumerate(ctc_probs):
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
|
|
|
|
uttid_list[idx], )
|
|
|
|
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
|
|
|
|
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
|
|
|
|
self.outdir,
|
|
|
|
|
uttid_list[idx],
|
|
|
|
|
)
|
|
|
|
|
self.outdir, uttid_list[idx], )
|
|
|
|
|
np.save(np_filename.format(trainer), ctc_prob)
|
|
|
|
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
|
|
|
|
|
|
|
|
@ -363,8 +342,7 @@ class PlotCTCReport(extension.Extension):
|
|
|
|
|
logger.add_figure(
|
|
|
|
|
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
|
|
|
|
plot.gcf(),
|
|
|
|
|
step,
|
|
|
|
|
)
|
|
|
|
|
step, )
|
|
|
|
|
else:
|
|
|
|
|
for idx, ctc_prob in enumerate(ctc_probs):
|
|
|
|
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
|
|
|
@ -420,8 +398,11 @@ class PlotCTCReport(extension.Extension):
|
|
|
|
|
for idx in set(topk_ids.reshape(-1).tolist()):
|
|
|
|
|
if idx == 0:
|
|
|
|
|
plt.plot(
|
|
|
|
|
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
|
|
|
|
|
)
|
|
|
|
|
times_probs,
|
|
|
|
|
ctc_prob[:, 0],
|
|
|
|
|
":",
|
|
|
|
|
label="<blank>",
|
|
|
|
|
color="grey")
|
|
|
|
|
else:
|
|
|
|
|
plt.plot(times_probs, ctc_prob[:, idx])
|
|
|
|
|
plt.xlabel(u"Input [frame]", fontsize=12)
|
|
|
|
@ -434,4 +415,4 @@ class PlotCTCReport(extension.Extension):
|
|
|
|
|
def _plot_and_save_ctc(self, ctc_prob, filename):
|
|
|
|
|
plt = self.draw_ctc_plot(ctc_prob)
|
|
|
|
|
plt.savefig(filename)
|
|
|
|
|
plt.close()
|
|
|
|
|
plt.close()
|
|
|
|
|