commit
9e2773dffb
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,75 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""ST Interface module."""
|
||||||
|
from .asr_interface import ASRInterface
|
||||||
|
from deepspeech.utils.dynamic_import import dynamic_import
|
||||||
|
|
||||||
|
|
||||||
|
class STInterface(ASRInterface):
|
||||||
|
"""ST Interface model implementation.
|
||||||
|
|
||||||
|
NOTE: This class is inherited from ASRInterface to enable joint translation
|
||||||
|
and recognition when performing multi-task learning with the ASR task.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def translate(self,
|
||||||
|
x,
|
||||||
|
trans_args,
|
||||||
|
char_list=None,
|
||||||
|
rnnlm=None,
|
||||||
|
ensemble_models=[]):
|
||||||
|
"""Recognize x for evaluation.
|
||||||
|
|
||||||
|
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
|
||||||
|
:param namespace trans_args: argment namespace contraining options
|
||||||
|
:param list char_list: list of characters
|
||||||
|
:param paddle.nn.Layer rnnlm: language model module
|
||||||
|
:return: N-best decoding results
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("translate method is not implemented")
|
||||||
|
|
||||||
|
def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
|
||||||
|
"""Beam search implementation for batch.
|
||||||
|
|
||||||
|
:param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
|
||||||
|
:param namespace trans_args: argument namespace containing options
|
||||||
|
:param list char_list: list of characters
|
||||||
|
:param paddle.nn.Layer rnnlm: language model module
|
||||||
|
:return: N-best decoding results
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Batch decoding is not supported yet.")
|
||||||
|
|
||||||
|
|
||||||
|
predefined_st = {
|
||||||
|
"transformer": "deepspeech.models.u2_st:U2STModel",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_import_st(module):
|
||||||
|
"""Import ST models dynamically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module (str): module_name:class_name or alias in `predefined_st`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
type: ST class
|
||||||
|
|
||||||
|
"""
|
||||||
|
model_class = dynamic_import(module, predefined_st)
|
||||||
|
assert issubclass(model_class,
|
||||||
|
STInterface), f"{module} does not implement STInterface"
|
||||||
|
return model_class
|
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .u2_st import U2STInferModel
|
||||||
|
from .u2_st import U2STModel
|
@ -0,0 +1,418 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from . import extension
|
||||||
|
|
||||||
|
|
||||||
|
class PlotAttentionReport(extension.Extension):
|
||||||
|
"""Plot attention reporter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
|
||||||
|
Function of attention visualization.
|
||||||
|
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
||||||
|
outdir (str): Directory to save figures.
|
||||||
|
converter (espnet.asr.*_backend.asr.CustomConverter):
|
||||||
|
Function to convert data.
|
||||||
|
device (int | torch.device): Device.
|
||||||
|
reverse (bool): If True, input and output length are reversed.
|
||||||
|
ikey (str): Key to access input
|
||||||
|
(for ASR/ST ikey="input", for MT ikey="output".)
|
||||||
|
iaxis (int): Dimension to access input
|
||||||
|
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
||||||
|
okey (str): Key to access output
|
||||||
|
(for ASR/ST okey="input", MT okay="output".)
|
||||||
|
oaxis (int): Dimension to access output
|
||||||
|
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
||||||
|
subsampling_factor (int): subsampling factor in encoder
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
att_vis_fn,
|
||||||
|
data,
|
||||||
|
outdir,
|
||||||
|
converter,
|
||||||
|
transform,
|
||||||
|
device,
|
||||||
|
reverse=False,
|
||||||
|
ikey="input",
|
||||||
|
iaxis=0,
|
||||||
|
okey="output",
|
||||||
|
oaxis=0,
|
||||||
|
subsampling_factor=1, ):
|
||||||
|
self.att_vis_fn = att_vis_fn
|
||||||
|
self.data = copy.deepcopy(data)
|
||||||
|
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
||||||
|
# key is utterance ID
|
||||||
|
self.outdir = outdir
|
||||||
|
self.converter = converter
|
||||||
|
self.transform = transform
|
||||||
|
self.device = device
|
||||||
|
self.reverse = reverse
|
||||||
|
self.ikey = ikey
|
||||||
|
self.iaxis = iaxis
|
||||||
|
self.okey = okey
|
||||||
|
self.oaxis = oaxis
|
||||||
|
self.factor = subsampling_factor
|
||||||
|
if not os.path.exists(self.outdir):
|
||||||
|
os.makedirs(self.outdir)
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
"""Plot and save image file of att_ws matrix."""
|
||||||
|
att_ws, uttid_list = self.get_attention_weights()
|
||||||
|
if isinstance(att_ws, list): # multi-encoder case
|
||||||
|
num_encs = len(att_ws) - 1
|
||||||
|
# atts
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, att_w in enumerate(att_ws[i]):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
np.save(np_filename.format(trainer), att_w)
|
||||||
|
self._plot_and_save_attention(att_w,
|
||||||
|
filename.format(trainer))
|
||||||
|
# han
|
||||||
|
for idx, att_w in enumerate(att_ws[num_encs]):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
|
||||||
|
self.outdir, uttid_list[idx], )
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], )
|
||||||
|
np.save(np_filename.format(trainer), att_w)
|
||||||
|
self._plot_and_save_attention(
|
||||||
|
att_w, filename.format(trainer), han_mode=True)
|
||||||
|
else:
|
||||||
|
for idx, att_w in enumerate(att_ws):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
||||||
|
uttid_list[idx], )
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], )
|
||||||
|
np.save(np_filename.format(trainer), att_w)
|
||||||
|
self._plot_and_save_attention(att_w, filename.format(trainer))
|
||||||
|
|
||||||
|
def log_attentions(self, logger, step):
|
||||||
|
"""Add image files of att_ws matrix to the tensorboard."""
|
||||||
|
att_ws, uttid_list = self.get_attention_weights()
|
||||||
|
if isinstance(att_ws, list): # multi-encoder case
|
||||||
|
num_encs = len(att_ws) - 1
|
||||||
|
# atts
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, att_w in enumerate(att_ws[i]):
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
plot = self.draw_attention_plot(att_w)
|
||||||
|
logger.add_figure(
|
||||||
|
"%s_att%d" % (uttid_list[idx], i + 1),
|
||||||
|
plot.gcf(),
|
||||||
|
step, )
|
||||||
|
# han
|
||||||
|
for idx, att_w in enumerate(att_ws[num_encs]):
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
plot = self.draw_han_plot(att_w)
|
||||||
|
logger.add_figure(
|
||||||
|
"%s_han" % (uttid_list[idx]),
|
||||||
|
plot.gcf(),
|
||||||
|
step, )
|
||||||
|
else:
|
||||||
|
for idx, att_w in enumerate(att_ws):
|
||||||
|
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
|
||||||
|
plot = self.draw_attention_plot(att_w)
|
||||||
|
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
||||||
|
|
||||||
|
def get_attention_weights(self):
|
||||||
|
"""Return attention weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: attention weights. float. Its shape would be
|
||||||
|
differ from backend.
|
||||||
|
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
|
||||||
|
other case => (B, Lmax, Tmax).
|
||||||
|
* chainer-> (B, Lmax, Tmax)
|
||||||
|
|
||||||
|
"""
|
||||||
|
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
||||||
|
batch = self.converter([return_batch], self.device)
|
||||||
|
if isinstance(batch, tuple):
|
||||||
|
att_ws = self.att_vis_fn(*batch)
|
||||||
|
else:
|
||||||
|
att_ws = self.att_vis_fn(**batch)
|
||||||
|
return att_ws, uttid_list
|
||||||
|
|
||||||
|
def trim_attention_weight(self, uttid, att_w):
|
||||||
|
"""Transform attention matrix with regard to self.reverse."""
|
||||||
|
if self.reverse:
|
||||||
|
enc_key, enc_axis = self.okey, self.oaxis
|
||||||
|
dec_key, dec_axis = self.ikey, self.iaxis
|
||||||
|
else:
|
||||||
|
enc_key, enc_axis = self.ikey, self.iaxis
|
||||||
|
dec_key, dec_axis = self.okey, self.oaxis
|
||||||
|
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
|
||||||
|
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
|
||||||
|
if self.factor > 1:
|
||||||
|
enc_len //= self.factor
|
||||||
|
if len(att_w.shape) == 3:
|
||||||
|
att_w = att_w[:, :dec_len, :enc_len]
|
||||||
|
else:
|
||||||
|
att_w = att_w[:dec_len, :enc_len]
|
||||||
|
return att_w
|
||||||
|
|
||||||
|
def draw_attention_plot(self, att_w):
|
||||||
|
"""Plot the att_w matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib.pyplot: pyplot object with attention matrix image.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
att_w = att_w.astype(np.float32)
|
||||||
|
if len(att_w.shape) == 3:
|
||||||
|
for h, aw in enumerate(att_w, 1):
|
||||||
|
plt.subplot(1, len(att_w), h)
|
||||||
|
plt.imshow(aw, aspect="auto")
|
||||||
|
plt.xlabel("Encoder Index")
|
||||||
|
plt.ylabel("Decoder Index")
|
||||||
|
else:
|
||||||
|
plt.imshow(att_w, aspect="auto")
|
||||||
|
plt.xlabel("Encoder Index")
|
||||||
|
plt.ylabel("Decoder Index")
|
||||||
|
plt.tight_layout()
|
||||||
|
return plt
|
||||||
|
|
||||||
|
def draw_han_plot(self, att_w):
|
||||||
|
"""Plot the att_w matrix for hierarchical attention.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib.pyplot: pyplot object with attention matrix image.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
if len(att_w.shape) == 3:
|
||||||
|
for h, aw in enumerate(att_w, 1):
|
||||||
|
legends = []
|
||||||
|
plt.subplot(1, len(att_w), h)
|
||||||
|
for i in range(aw.shape[1]):
|
||||||
|
plt.plot(aw[:, i])
|
||||||
|
legends.append("Att{}".format(i))
|
||||||
|
plt.ylim([0, 1.0])
|
||||||
|
plt.xlim([0, aw.shape[0]])
|
||||||
|
plt.grid(True)
|
||||||
|
plt.ylabel("Attention Weight")
|
||||||
|
plt.xlabel("Decoder Index")
|
||||||
|
plt.legend(legends)
|
||||||
|
else:
|
||||||
|
legends = []
|
||||||
|
for i in range(att_w.shape[1]):
|
||||||
|
plt.plot(att_w[:, i])
|
||||||
|
legends.append("Att{}".format(i))
|
||||||
|
plt.ylim([0, 1.0])
|
||||||
|
plt.xlim([0, att_w.shape[0]])
|
||||||
|
plt.grid(True)
|
||||||
|
plt.ylabel("Attention Weight")
|
||||||
|
plt.xlabel("Decoder Index")
|
||||||
|
plt.legend(legends)
|
||||||
|
plt.tight_layout()
|
||||||
|
return plt
|
||||||
|
|
||||||
|
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
|
||||||
|
if han_mode:
|
||||||
|
plt = self.draw_han_plot(att_w)
|
||||||
|
else:
|
||||||
|
plt = self.draw_attention_plot(att_w)
|
||||||
|
plt.savefig(filename)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
class PlotCTCReport(extension.Extension):
|
||||||
|
"""Plot CTC reporter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
|
||||||
|
Function of CTC visualization.
|
||||||
|
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
|
||||||
|
outdir (str): Directory to save figures.
|
||||||
|
converter (espnet.asr.*_backend.asr.CustomConverter):
|
||||||
|
Function to convert data.
|
||||||
|
device (int | torch.device): Device.
|
||||||
|
reverse (bool): If True, input and output length are reversed.
|
||||||
|
ikey (str): Key to access input
|
||||||
|
(for ASR/ST ikey="input", for MT ikey="output".)
|
||||||
|
iaxis (int): Dimension to access input
|
||||||
|
(for ASR/ST iaxis=0, for MT iaxis=1.)
|
||||||
|
okey (str): Key to access output
|
||||||
|
(for ASR/ST okey="input", MT okay="output".)
|
||||||
|
oaxis (int): Dimension to access output
|
||||||
|
(for ASR/ST oaxis=0, for MT oaxis=0.)
|
||||||
|
subsampling_factor (int): subsampling factor in encoder
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ctc_vis_fn,
|
||||||
|
data,
|
||||||
|
outdir,
|
||||||
|
converter,
|
||||||
|
transform,
|
||||||
|
device,
|
||||||
|
reverse=False,
|
||||||
|
ikey="input",
|
||||||
|
iaxis=0,
|
||||||
|
okey="output",
|
||||||
|
oaxis=0,
|
||||||
|
subsampling_factor=1, ):
|
||||||
|
self.ctc_vis_fn = ctc_vis_fn
|
||||||
|
self.data = copy.deepcopy(data)
|
||||||
|
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
|
||||||
|
# key is utterance ID
|
||||||
|
self.outdir = outdir
|
||||||
|
self.converter = converter
|
||||||
|
self.transform = transform
|
||||||
|
self.device = device
|
||||||
|
self.reverse = reverse
|
||||||
|
self.ikey = ikey
|
||||||
|
self.iaxis = iaxis
|
||||||
|
self.okey = okey
|
||||||
|
self.oaxis = oaxis
|
||||||
|
self.factor = subsampling_factor
|
||||||
|
if not os.path.exists(self.outdir):
|
||||||
|
os.makedirs(self.outdir)
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
"""Plot and save image file of ctc prob."""
|
||||||
|
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||||
|
if isinstance(ctc_probs, list): # multi-encoder case
|
||||||
|
num_encs = len(ctc_probs) - 1
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], i + 1, )
|
||||||
|
np.save(np_filename.format(trainer), ctc_prob)
|
||||||
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||||
|
else:
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs):
|
||||||
|
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
|
||||||
|
uttid_list[idx], )
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
|
||||||
|
self.outdir, uttid_list[idx], )
|
||||||
|
np.save(np_filename.format(trainer), ctc_prob)
|
||||||
|
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
|
||||||
|
|
||||||
|
def log_ctc_probs(self, logger, step):
|
||||||
|
"""Add image files of ctc probs to the tensorboard."""
|
||||||
|
ctc_probs, uttid_list = self.get_ctc_probs()
|
||||||
|
if isinstance(ctc_probs, list): # multi-encoder case
|
||||||
|
num_encs = len(ctc_probs) - 1
|
||||||
|
for i in range(num_encs):
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs[i]):
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
plot = self.draw_ctc_plot(ctc_prob)
|
||||||
|
logger.add_figure(
|
||||||
|
"%s_ctc%d" % (uttid_list[idx], i + 1),
|
||||||
|
plot.gcf(),
|
||||||
|
step, )
|
||||||
|
else:
|
||||||
|
for idx, ctc_prob in enumerate(ctc_probs):
|
||||||
|
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
|
||||||
|
plot = self.draw_ctc_plot(ctc_prob)
|
||||||
|
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
|
||||||
|
|
||||||
|
def get_ctc_probs(self):
|
||||||
|
"""Return CTC probs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
numpy.ndarray: CTC probs. float. Its shape would be
|
||||||
|
differ from backend. (B, Tmax, vocab).
|
||||||
|
|
||||||
|
"""
|
||||||
|
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
|
||||||
|
batch = self.converter([return_batch], self.device)
|
||||||
|
if isinstance(batch, tuple):
|
||||||
|
probs = self.ctc_vis_fn(*batch)
|
||||||
|
else:
|
||||||
|
probs = self.ctc_vis_fn(**batch)
|
||||||
|
return probs, uttid_list
|
||||||
|
|
||||||
|
def trim_ctc_prob(self, uttid, prob):
|
||||||
|
"""Trim CTC posteriors accoding to input lengths."""
|
||||||
|
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
|
||||||
|
if self.factor > 1:
|
||||||
|
enc_len //= self.factor
|
||||||
|
prob = prob[:enc_len]
|
||||||
|
return prob
|
||||||
|
|
||||||
|
def draw_ctc_plot(self, ctc_prob):
|
||||||
|
"""Plot the ctc_prob matrix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
matplotlib.pyplot: pyplot object with CTC prob matrix image.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
ctc_prob = ctc_prob.astype(np.float32)
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
topk_ids = np.argsort(ctc_prob, axis=1)
|
||||||
|
n_frames, vocab = ctc_prob.shape
|
||||||
|
times_probs = np.arange(n_frames)
|
||||||
|
|
||||||
|
plt.figure(figsize=(20, 8))
|
||||||
|
|
||||||
|
# NOTE: index 0 is reserved for blank
|
||||||
|
for idx in set(topk_ids.reshape(-1).tolist()):
|
||||||
|
if idx == 0:
|
||||||
|
plt.plot(
|
||||||
|
times_probs,
|
||||||
|
ctc_prob[:, 0],
|
||||||
|
":",
|
||||||
|
label="<blank>",
|
||||||
|
color="grey")
|
||||||
|
else:
|
||||||
|
plt.plot(times_probs, ctc_prob[:, idx])
|
||||||
|
plt.xlabel(u"Input [frame]", fontsize=12)
|
||||||
|
plt.ylabel("Posteriors", fontsize=12)
|
||||||
|
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
|
||||||
|
plt.yticks(list(range(0, 2, 1)))
|
||||||
|
plt.tight_layout()
|
||||||
|
return plt
|
||||||
|
|
||||||
|
def _plot_and_save_ctc(self, ctc_prob, filename):
|
||||||
|
plt = self.draw_ctc_plot(ctc_prob)
|
||||||
|
plt.savefig(filename)
|
||||||
|
plt.close()
|
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from ..reporter import DictSummary
|
||||||
|
from .utils import get_trigger
|
||||||
|
|
||||||
|
|
||||||
|
class CompareValueTrigger():
|
||||||
|
"""Trigger invoked when key value getting bigger or lower than before.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key (str) : Key of value.
|
||||||
|
compare_fn ((float, float) -> bool) : Function to compare the values.
|
||||||
|
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
|
||||||
|
self._key = key
|
||||||
|
self._best_value = None
|
||||||
|
self._interval_trigger = get_trigger(trigger)
|
||||||
|
self._init_summary()
|
||||||
|
self._compare_fn = compare_fn
|
||||||
|
|
||||||
|
def __call__(self, trainer):
|
||||||
|
"""Get value related to the key and compare with current value."""
|
||||||
|
observation = trainer.observation
|
||||||
|
summary = self._summary
|
||||||
|
key = self._key
|
||||||
|
if key in observation:
|
||||||
|
summary.add({key: observation[key]})
|
||||||
|
|
||||||
|
if not self._interval_trigger(trainer):
|
||||||
|
return False
|
||||||
|
|
||||||
|
stats = summary.compute_mean()
|
||||||
|
value = float(stats[key]) # copy to CPU
|
||||||
|
self._init_summary()
|
||||||
|
|
||||||
|
if self._best_value is None:
|
||||||
|
# initialize best value
|
||||||
|
self._best_value = value
|
||||||
|
return False
|
||||||
|
elif self._compare_fn(self._best_value, value):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
self._best_value = value
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _init_summary(self):
|
||||||
|
self._summary = DictSummary()
|
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from .interval_trigger import IntervalTrigger
|
||||||
|
|
||||||
|
|
||||||
|
def never_fail_trigger(trainer):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_trigger(trigger):
|
||||||
|
if trigger is None:
|
||||||
|
return never_fail_trigger
|
||||||
|
if callable(trigger):
|
||||||
|
return trigger
|
||||||
|
else:
|
||||||
|
trigger = IntervalTrigger(*trigger)
|
||||||
|
return trigger
|
Loading…
Reference in new issue