tiny change, not important

pull/930/head
huangyuxin 4 years ago
commit 9e2773dffb

@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!") "register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn ############# ########### hack paddle.nn #############
from paddle.nn import Layer from paddle.nn import Layer
from typing import Optional from typing import Optional
@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
logger.debug( logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!") "register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict) setattr(paddle.nn, 'LayerDict', LayerDict)

@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json
from pathlib import Path
import jsonlines import jsonlines
import paddle import paddle
import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from .beam_search import BatchBeamSearch from .beam_search import BatchBeamSearch

@ -21,8 +21,6 @@ from distutils.util import strtobool
import configargparse import configargparse
import numpy as np import numpy as np
from .recog import recog_v2
def get_parser(): def get_parser():
"""Get default arguments.""" """Get default arguments."""

@ -18,7 +18,7 @@ from deepspeech.utils.dynamic_import import dynamic_import
class ASRInterface: class ASRInterface:
"""ASR Interface for ESPnet model implementation.""" """ASR Interface model implementation."""
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser):
@ -103,14 +103,14 @@ class ASRInterface:
@property @property
def attention_plot_class(self): def attention_plot_class(self):
"""Get attention plot class.""" """Get attention plot class."""
from espnet.asr.asr_utils import PlotAttentionReport from deepspeech.training.extensions.plot import PlotAttentionReport
return PlotAttentionReport return PlotAttentionReport
@property @property
def ctc_plot_class(self): def ctc_plot_class(self):
"""Get CTC plot class.""" """Get CTC plot class."""
from espnet.asr.asr_utils import PlotCTCReport from deepspeech.training.extensions.plot import PlotCTCReport
return PlotCTCReport return PlotCTCReport

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -1,12 +1,25 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface.""" """Language model interface."""
import argparse import argparse
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import
class LMInterface(ScorerInterface): class LMInterface(ScorerInterface):
"""LM Interface for ESPnet model implementation.""" """LM Interface model implementation."""
@staticmethod @staticmethod
def add_arguments(parser): def add_arguments(parser):
@ -52,6 +65,7 @@ predefined_lms = {
"transformer": "deepspeech.models.lm.transformer:TransformerLM", "transformer": "deepspeech.models.lm.transformer:TransformerLM",
} }
def dynamic_import_lm(module): def dynamic_import_lm(module):
"""Import LM class dynamically. """Import LM class dynamically.
@ -63,7 +77,6 @@ def dynamic_import_lm(module):
""" """
model_class = dynamic_import(module, predefined_lms) model_class = dynamic_import(module, predefined_lms)
assert issubclass( assert issubclass(model_class,
model_class, LMInterface LMInterface), f"{module} does not implement LMInterface"
), f"{module} does not implement LMInterface"
return model_class return model_class

@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.
    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=None):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :param list ensemble_models: additional models; not used by this
            base implementation. Defaults to None instead of a shared
            mutable list, so overriding code cannot leak state across calls.
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Batch decoding is not supported yet.")
# Alias table for ST models: short name -> "module_name:class_name".
# Passed to dynamic_import() by dynamic_import_st() to resolve aliases.
predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
    """Resolve an ST model class from its name.

    Args:
        module (str): Either ``module_name:class_name`` or one of the
            aliases listed in `predefined_st`.

    Returns:
        type: The resolved class; an assertion checks that it is a
            subclass of STInterface.
    """
    st_class = dynamic_import(module, predefined_st)
    implements_st = issubclass(st_class, STInterface)
    assert implements_st, f"{module} does not implement STInterface"
    return st_class

@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union
import paddle import paddle
from paddle import nn from paddle import nn
from typing import Union
from paddle.nn import functional as F from paddle.nn import functional as F
from typeguard import check_argument_types from typeguard import check_argument_types

@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"] __all__ = [
"NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
]
class NoPositionalEncoding(nn.Layer): class NoPositionalEncoding(nn.Layer):
def __init__(self, def __init__(self,

@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos": elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type is "no_pos": elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding pos_enc_class = NoPositionalEncoding
else: else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

@ -61,7 +61,7 @@ class LinearNoSubsampling(BaseSubsampling):
nn.Linear(idim, odim), nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12), nn.LayerNorm(odim, epsilon=1e-12),
nn.Dropout(dropout_rate), nn.Dropout(dropout_rate),
nn.ReLU(),) nn.ReLU(), )
self.right_context = 0 self.right_context = 0
self.subsampling_rate = 1 self.subsampling_rate = 1
@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask return x, pos_emb, x_mask
class Conv2dSubsampling(BaseSubsampling): class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class Conv2dSubsampling4(Conv2dSubsampling): class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).""" """Convolutional 2D subsampling (to 1/4 length)."""

@ -0,0 +1,418 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
import numpy as np
from . import extension
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter.

    Args:
        att_vis_fn (callable): Function of attention visualization. It is
            called with the converted batch (unpacked positionally for a
            tuple, as keyword arguments otherwise).
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures. Created if it is missing.
        converter (callable): Function to convert data; called as
            ``converter([batch], device)``.
        transform (callable): Called as ``transform(data, return_uttid=True)``
            and expected to return ``(batch, uttid_list)``.
        device: Device handle forwarded unchanged to ``converter``.
        reverse (bool): If True, input and output length are reversed.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="input", MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder
    """

    def __init__(
            self,
            att_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.att_vis_fn = att_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID; built from its own deep copy so that the
        # list view and the dict view never share mutable entries.
        self.data_dict = dict(copy.deepcopy(data))
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        # exist_ok avoids the check-then-create race when several
        # processes construct the reporter at the same time.
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save image file of att_ws matrix."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            # the last list entry holds the hierarchical attention (han)
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    self._save_attention(trainer, uttid_list[idx], att_w,
                                         ".att%d" % (i + 1))
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                self._save_attention(
                    trainer, uttid_list[idx], att_w, ".han", han_mode=True)
        else:
            for idx, att_w in enumerate(att_ws):
                self._save_attention(trainer, uttid_list[idx], att_w, "")

    def _save_attention(self, trainer, uttid, att_w, tag, han_mode=False):
        """Trim one attention matrix and write its .npy and .png files."""
        # "{.updater.epoch}" is resolved by str.format(trainer) below.
        stem = "%s/%s.ep.{.updater.epoch}%s" % (self.outdir, uttid, tag)
        att_w = self.trim_attention_weight(uttid, att_w)
        np.save((stem + ".npy").format(trainer), att_w)
        self._plot_and_save_attention(
            att_w, (stem + ".png").format(trainer), han_mode=han_mode)

    def log_attentions(self, logger, step):
        """Add image files of att_ws matrix to the tensorboard."""
        att_ws, uttid_list = self.get_attention_weights()
        if isinstance(att_ws, list):  # multi-encoder case
            num_encs = len(att_ws) - 1
            # atts
            for i in range(num_encs):
                for idx, att_w in enumerate(att_ws[i]):
                    att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                    plot = self.draw_attention_plot(att_w)
                    logger.add_figure(
                        "%s_att%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
            # han
            for idx, att_w in enumerate(att_ws[num_encs]):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_han_plot(att_w)
                logger.add_figure(
                    "%s_han" % (uttid_list[idx]),
                    plot.gcf(),
                    step, )
        else:
            for idx, att_w in enumerate(att_ws):
                att_w = self.trim_attention_weight(uttid_list[idx], att_w)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_attention_weights(self):
        """Return attention weights together with the utterance IDs.

        Returns:
            tuple: ``(att_ws, uttid_list)``. ``att_ws`` is whatever
                ``att_vis_fn`` returns; the trimming code in this class
                handles per-utterance arrays that are either
                (H, Lmax, Tmax) or (Lmax, Tmax), and a list of such
                batches in the multi-encoder case.
        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        else:
            att_ws = self.att_vis_fn(**batch)
        return att_ws, uttid_list

    def trim_attention_weight(self, uttid, att_w):
        """Transform attention matrix with regard to self.reverse."""
        if self.reverse:
            enc_key, enc_axis = self.okey, self.oaxis
            dec_key, dec_axis = self.ikey, self.iaxis
        else:
            enc_key, enc_axis = self.ikey, self.iaxis
            dec_key, dec_axis = self.okey, self.oaxis
        dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
        enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
        if self.factor > 1:
            # encoder length shrinks by the subsampling factor
            enc_len //= self.factor
        if len(att_w.shape) == 3:
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Plot the att_w matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.
        """
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        att_w = att_w.astype(np.float32)
        if len(att_w.shape) == 3:
            # one subplot per leading-axis slice
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def draw_han_plot(self, att_w):
        """Plot the att_w matrix for hierarchical attention.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.
        """
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        plt.clf()
        if len(att_w.shape) == 3:
            for h, aw in enumerate(att_w, 1):
                legends = []
                plt.subplot(1, len(att_w), h)
                for i in range(aw.shape[1]):
                    plt.plot(aw[:, i])
                    legends.append("Att{}".format(i))
                plt.ylim([0, 1.0])
                plt.xlim([0, aw.shape[0]])
                plt.grid(True)
                plt.ylabel("Attention Weight")
                plt.xlabel("Decoder Index")
                plt.legend(legends)
        else:
            legends = []
            for i in range(att_w.shape[1]):
                plt.plot(att_w[:, i])
                legends.append("Att{}".format(i))
            plt.ylim([0, 1.0])
            plt.xlim([0, att_w.shape[0]])
            plt.grid(True)
            plt.ylabel("Attention Weight")
            plt.xlabel("Decoder Index")
            plt.legend(legends)
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename, han_mode=False):
        """Draw att_w (han or regular layout), save it to filename, close."""
        if han_mode:
            plt = self.draw_han_plot(att_w)
        else:
            plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()
class PlotCTCReport(extension.Extension):
    """Plot CTC reporter.

    Args:
        ctc_vis_fn (callable): Function of CTC visualization. It is called
            with the converted batch (unpacked positionally for a tuple,
            as keyword arguments otherwise).
        data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
        outdir (str): Directory to save figures. Created if it is missing.
        converter (callable): Function to convert data; called as
            ``converter([batch], device)``.
        transform (callable): Called as ``transform(data, return_uttid=True)``
            and expected to return ``(batch, uttid_list)``.
        device: Device handle forwarded unchanged to ``converter``.
        reverse (bool): If True, input and output length are reversed.
            Stored on the instance; not read by any method of this class.
        ikey (str): Key to access input
            (for ASR/ST ikey="input", for MT ikey="output".)
        iaxis (int): Dimension to access input
            (for ASR/ST iaxis=0, for MT iaxis=1.)
        okey (str): Key to access output
            (for ASR/ST okey="input", MT okey="output".)
        oaxis (int): Dimension to access output
            (for ASR/ST oaxis=0, for MT oaxis=0.)
        subsampling_factor (int): subsampling factor in encoder
    """

    def __init__(
            self,
            ctc_vis_fn,
            data,
            outdir,
            converter,
            transform,
            device,
            reverse=False,
            ikey="input",
            iaxis=0,
            okey="output",
            oaxis=0,
            subsampling_factor=1, ):
        self.ctc_vis_fn = ctc_vis_fn
        self.data = copy.deepcopy(data)
        # key is utterance ID; built from its own deep copy so that the
        # list view and the dict view never share mutable entries.
        self.data_dict = dict(copy.deepcopy(data))
        self.outdir = outdir
        self.converter = converter
        self.transform = transform
        self.device = device
        self.reverse = reverse
        self.ikey = ikey
        self.iaxis = iaxis
        self.okey = okey
        self.oaxis = oaxis
        self.factor = subsampling_factor
        # exist_ok avoids the check-then-create race when several
        # processes construct the reporter at the same time.
        os.makedirs(self.outdir, exist_ok=True)

    def __call__(self, trainer):
        """Plot and save image file of ctc prob."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            # NOTE: the last list entry is deliberately not plotted,
            # matching the range used by log_ctc_probs.
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    self._save_ctc_prob(trainer, uttid_list[idx], ctc_prob,
                                        ".ctc%d" % (i + 1))
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                self._save_ctc_prob(trainer, uttid_list[idx], ctc_prob, "")

    def _save_ctc_prob(self, trainer, uttid, ctc_prob, tag):
        """Trim one CTC posterior matrix and write its .npy and .png files."""
        # "{.updater.epoch}" is resolved by str.format(trainer) below.
        stem = "%s/%s.ep.{.updater.epoch}%s" % (self.outdir, uttid, tag)
        ctc_prob = self.trim_ctc_prob(uttid, ctc_prob)
        np.save((stem + ".npy").format(trainer), ctc_prob)
        self._plot_and_save_ctc(ctc_prob, (stem + ".png").format(trainer))

    def log_ctc_probs(self, logger, step):
        """Add image files of ctc probs to the tensorboard."""
        ctc_probs, uttid_list = self.get_ctc_probs()
        if isinstance(ctc_probs, list):  # multi-encoder case
            num_encs = len(ctc_probs) - 1
            for i in range(num_encs):
                for idx, ctc_prob in enumerate(ctc_probs[i]):
                    ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                    plot = self.draw_ctc_plot(ctc_prob)
                    logger.add_figure(
                        "%s_ctc%d" % (uttid_list[idx], i + 1),
                        plot.gcf(),
                        step, )
        else:
            for idx, ctc_prob in enumerate(ctc_probs):
                ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
                plot = self.draw_ctc_plot(ctc_prob)
                logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)

    def get_ctc_probs(self):
        """Return CTC probs together with the utterance IDs.

        Returns:
            tuple: ``(probs, uttid_list)``. ``probs`` is whatever
                ``ctc_vis_fn`` returns; the plotting code in this class
                expects per-utterance arrays of shape (Tmax, vocab).
        """
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            probs = self.ctc_vis_fn(*batch)
        else:
            probs = self.ctc_vis_fn(**batch)
        return probs, uttid_list

    def trim_ctc_prob(self, uttid, prob):
        """Trim CTC posteriors according to input lengths."""
        enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
        if self.factor > 1:
            # encoder length shrinks by the subsampling factor
            enc_len //= self.factor
        prob = prob[:enc_len]
        return prob

    def draw_ctc_plot(self, ctc_prob):
        """Plot the ctc_prob matrix.

        Returns:
            matplotlib.pyplot: pyplot object with CTC prob matrix image.
        """
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        ctc_prob = ctc_prob.astype(np.float32)
        plt.clf()
        topk_ids = np.argsort(ctc_prob, axis=1)
        # unpacking also enforces that ctc_prob is two-dimensional
        n_frames, _ = ctc_prob.shape
        times_probs = np.arange(n_frames)
        plt.figure(figsize=(20, 8))
        # NOTE: index 0 is reserved for blank
        for idx in set(topk_ids.reshape(-1).tolist()):
            if idx == 0:
                plt.plot(
                    times_probs,
                    ctc_prob[:, 0],
                    ":",
                    label="<blank>",
                    color="grey")
            else:
                plt.plot(times_probs, ctc_prob[:, idx])
        plt.xlabel(u"Input [frame]", fontsize=12)
        plt.ylabel("Posteriors", fontsize=12)
        plt.xticks(list(range(0, int(n_frames) + 1, 10)))
        plt.yticks(list(range(0, 2, 1)))
        plt.tight_layout()
        return plt

    def _plot_and_save_ctc(self, ctc_prob, filename):
        """Draw ctc_prob, save the figure to filename and close it."""
        plt = self.draw_ctc_plot(ctc_prob)
        plt.savefig(filename)
        plt.close()

@ -11,18 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
    """Trigger that never fires: returns False for any trainer."""
    return False
def get_trigger(trigger):
    """Normalize a trigger specification into a callable trigger.

    None maps to never_fail_trigger, a callable is returned unchanged,
    and anything else is unpacked into IntervalTrigger(*trigger).
    """
    if trigger is None:
        return never_fail_trigger
    return trigger if callable(trigger) else IntervalTrigger(*trigger)

@ -0,0 +1,61 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..reporter import DictSummary
from .utils import get_trigger
class CompareValueTrigger():
    """Trigger driven by comparing an observed value with its stored best.

    Args:
        key (str) : Observation key whose value is monitored.
        compare_fn ((float, float) -> bool) : Called as
            ``compare_fn(best_value, value)``; the trigger fires when the
            result is truthy.
        trigger (tuple(int, str)) : Specification handed to ``get_trigger``
            that decides the comparison interval.
    """

    def __init__(self, key, compare_fn, trigger=(1, "epoch")):
        self._compare_fn = compare_fn
        self._key = key
        self._best_value = None
        self._interval_trigger = get_trigger(trigger)
        self._init_summary()

    def __call__(self, trainer):
        """Accumulate the observed value and compare its mean at intervals."""
        observed = trainer.observation
        accumulator = self._summary
        name = self._key
        if name in observed:
            accumulator.add({name: observed[name]})

        # Between interval boundaries only accumulate, never fire.
        if not self._interval_trigger(trainer):
            return False

        mean_value = float(accumulator.compute_mean()[name])
        self._init_summary()

        previous_best = self._best_value
        if previous_best is not None and self._compare_fn(previous_best,
                                                          mean_value):
            # Fire without touching the stored reference value.
            return True

        # First measurement, or comparison did not fire: store the new value.
        self._best_value = mean_value
        return False

    def _init_summary(self):
        """Start a fresh DictSummary accumulator."""
        self._summary = DictSummary()

@ -30,3 +30,12 @@ class TimeTrigger():
return True return True
else: else:
return False return False
def state_dict(self):
    """Return the trigger state as a dict with the single key ``next_time``."""
    return {"next_time": self._next_time}
def set_state_dict(self, state_dict):
    """Restore ``_next_time`` from the ``next_time`` entry of state_dict."""
    restored_time = state_dict['next_time']
    self._next_time = restored_time

@ -0,0 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
def never_fail_trigger(trainer):
    """Trigger that never fires: returns False for any trainer."""
    return False
def get_trigger(trigger):
    """Normalize a trigger specification into a callable trigger.

    None maps to never_fail_trigger, a callable is returned unchanged,
    and anything else is unpacked into IntervalTrigger(*trigger).
    """
    if trigger is None:
        return never_fail_trigger
    return trigger if callable(trigger) else IntervalTrigger(*trigger)

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import numpy as np import numpy as np
__all__ = ["label_smoothing_dist"] __all__ = ["label_smoothing_dist"]
@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
if lsm_type == "unigram": if lsm_type == "unigram":
assert transcript is not None, ( assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type "transcript is required for %s label smoothing" % lsm_type)
)
labelcount = np.zeros(odim) labelcount = np.zeros(odim)
for k, v in trans_json.items(): for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])

@ -14,9 +14,9 @@
"""This module provides functions to calculate bleu score in different level. """This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
import sacrebleu
import nltk import nltk
import numpy as np import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu', "ErrorCalculator"] __all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
@ -106,11 +106,14 @@ class ErrorCalculator():
# NOTE: padding index (-1) in y_true is used to pad y_hat # NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1 # because y_hats is not padded with -1
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.pad, "") seq_hat_text = seq_hat_text.replace(self.pad, "")
seq_true_text = "".join(seq_true).replace(self.space, " ") seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text) seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text) seqs_true.append(seq_true_text)
bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat) bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
seqs_hat)
return bleu * 100 return bleu * 100

@ -14,11 +14,10 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
from itertools import groupby
import editdistance import editdistance
import numpy as np import numpy as np
import logging
import sys
from itertools import groupby
__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"] __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
@ -225,9 +224,12 @@ class ErrorCalculator():
:return: :return:
""" """
def __init__( def __init__(self,
self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False char_list,
): sym_space,
sym_blank,
report_cer=False,
report_wer=False):
"""Construct an ErrorCalculator object.""" """Construct an ErrorCalculator object."""
super().__init__() super().__init__()
@ -317,7 +319,9 @@ class ErrorCalculator():
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat # NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "") seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ") seq_true_text = "".join(seq_true).replace(self.space, " ")

@ -15,7 +15,6 @@ import contextlib
import inspect import inspect
import io import io
import os import os
import re
import subprocess as sp import subprocess as sp
import sys import sys
from pathlib import Path from pathlib import Path
@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
tools_extrs_dir = HERE / 'tools/extras' tools_extrs_dir = HERE / 'tools/extras'
with pushd(tools_extrs_dir): with pushd(tools_extrs_dir):
print(os.getcwd()) print(os.getcwd())
check_call(f"./install_autolog.sh") check_call("./install_autolog.sh")
print("autolog install.") print("autolog install.")
# ctcdecoder # ctcdecoder

@ -4,7 +4,6 @@
# 2018 Xuankai Chang (Shanghai Jiao Tong University) # 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse import argparse
import json
import logging import logging
import sys import sys

Loading…
Cancel
Save