format code

pull/931/head
Hui Zhang 3 years ago
parent 3fa2e44e89
commit e4ecfb22fd

@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!") "register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist) setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn ############# ########### hack paddle.nn #############
from paddle.nn import Layer from paddle.nn import Layer
from typing import Optional from typing import Optional
@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
logger.debug( logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!") "register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict) setattr(paddle.nn, 'LayerDict', LayerDict)

@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`.""" """V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json
from pathlib import Path
import jsonlines import jsonlines
import paddle import paddle
import yaml
from yacs.config import CfgNode from yacs.config import CfgNode
from .beam_search import BatchBeamSearch from .beam_search import BatchBeamSearch
@ -79,8 +75,7 @@ def recog_v2(args):
sort_in_input_length=False, sort_in_input_length=False,
preprocess_conf=confs.collator.augmentation_config preprocess_conf=confs.collator.augmentation_config
if args.preprocess_conf is None else args.preprocess_conf, if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={"train": False}, preprocess_args={"train": False}, )
)
if args.rnnlm: if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
@ -113,8 +108,7 @@ def recog_v2(args):
ctc=args.ctc_weight, ctc=args.ctc_weight,
lm=args.lm_weight, lm=args.lm_weight,
ngram=args.ngram_weight, ngram=args.ngram_weight,
length_bonus=args.penalty, length_bonus=args.penalty, )
)
beam_search = BeamSearch( beam_search = BeamSearch(
beam_size=args.beam_size, beam_size=args.beam_size,
vocab_size=len(char_list), vocab_size=len(char_list),
@ -123,8 +117,7 @@ def recog_v2(args):
sos=model.sos, sos=model.sos,
eos=model.eos, eos=model.eos,
token_list=char_list, token_list=char_list,
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )
)
# TODO(karita): make all scorers batchfied # TODO(karita): make all scorers batchfied
if args.batchsize == 1: if args.batchsize == 1:
@ -171,7 +164,8 @@ def recog_v2(args):
logger.info(f'feat: {feat.shape}') logger.info(f'feat: {feat.shape}')
enc = model.encode(paddle.to_tensor(feat).to(dtype)) enc = model.encode(paddle.to_tensor(feat).to(dtype))
logger.info(f'eout: {enc.shape}') logger.info(f'eout: {enc.shape}')
nbest_hyps = beam_search(x=enc, nbest_hyps = beam_search(
x=enc,
maxlenratio=args.maxlenratio, maxlenratio=args.maxlenratio,
minlenratio=args.minlenratio) minlenratio=args.minlenratio)
nbest_hyps = [ nbest_hyps = [
@ -183,9 +177,8 @@ def recog_v2(args):
item = new_js[name]['output'][0] # 1-best item = new_js[name]['output'][0] # 1-best
ref = item['text'] ref = item['text']
rec_text = item['rec_text'].replace('', rec_text = item['rec_text'].replace('', ' ').replace(
' ').replace('<eos>', '<eos>', '').strip()
'').strip()
rec_tokenid = list(map(int, item['rec_tokenid'].split())) rec_tokenid = list(map(int, item['rec_tokenid'].split()))
f.write({ f.write({
"utt": name, "utt": name,

@ -21,8 +21,6 @@ from distutils.util import strtobool
import configargparse import configargparse
import numpy as np import numpy as np
from .recog import recog_v2
def get_parser(): def get_parser():
"""Get default arguments.""" """Get default arguments."""

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -20,11 +20,11 @@ import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import from deepspeech.models.lm_interface import LMInterface
#LMInterface from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface): class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__( def __init__(
@ -89,8 +89,7 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0) m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m return ys_mask.unsqueeze(-2) & m
def forward( def forward(self, x: paddle.Tensor, t: paddle.Tensor
self, x: paddle.Tensor, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute LM loss value from buffer sequences. """Compute LM loss value from buffer sequences.
@ -117,7 +116,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb = self.embed(x) emb = self.embed(x)
h, _ = self.encoder(emb, xlen) h, _ = self.encoder(emb, xlen)
y = self.decoder(h) y = self.decoder(h)
loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") loss = F.cross_entropy(
y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
mask = xm.to(dtype=loss.dtype) mask = xm.to(dtype=loss.dtype)
logp = loss * mask.view(-1) logp = loss * mask.view(-1)
logp = logp.sum() logp = logp.sum()
@ -148,16 +148,16 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb = self.embed(y) emb = self.embed(y)
h, _, cache = self.encoder.forward_one_step( h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state emb, self._target_mask(y), cache=state)
)
h = self.decoder(h[:, -1]) h = self.decoder(h[:, -1])
logp = h.log_softmax(axis=-1).squeeze(0) logp = h.log_softmax(axis=-1).squeeze(0)
return logp, cache return logp, cache
# batch beam search API (see BatchScorerInterface) # batch beam search API (see BatchScorerInterface)
def batch_score( def batch_score(self,
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor ys: paddle.Tensor,
) -> Tuple[paddle.Tensor, List[Any]]: states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required). """Score new token batch (required).
Args: Args:
@ -191,13 +191,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
# batch decoding # batch decoding
h, _, states = self.encoder.forward_one_step( h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state emb, self._target_mask(ys), cache=batch_state)
)
h = self.decoder(h[:, -1]) h = self.decoder(h[:, -1])
logp = h.log_softmax(axi=-1) logp = h.log_softmax(axi=-1)
# transpose state of [layer, batch] into [batch, layer] # transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list return logp, state_list

@ -1,10 +1,23 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface.""" """Language model interface."""
import argparse import argparse
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
from deepspeech.utils.dynamic_import import dynamic_import from deepspeech.utils.dynamic_import import dynamic_import
class LMInterface(ScorerInterface): class LMInterface(ScorerInterface):
"""LM Interface model implementation.""" """LM Interface model implementation."""
@ -52,6 +65,7 @@ predefined_lms = {
"transformer": "deepspeech.models.lm.transformer:TransformerLM", "transformer": "deepspeech.models.lm.transformer:TransformerLM",
} }
def dynamic_import_lm(module): def dynamic_import_lm(module):
"""Import LM class dynamically. """Import LM class dynamically.
@ -63,7 +77,6 @@ def dynamic_import_lm(module):
""" """
model_class = dynamic_import(module, predefined_lms) model_class = dynamic_import(module, predefined_lms)
assert issubclass( assert issubclass(model_class,
model_class, LMInterface LMInterface), f"{module} does not implement LMInterface"
), f"{module} does not implement LMInterface"
return model_class return model_class

@ -1,9 +1,20 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module.""" """ST Interface module."""
import argparse
from deepspeech.utils.dynamic_import import dynamic_import
from .asr_interface import ASRInterface from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface): class STInterface(ASRInterface):
"""ST Interface model implementation. """ST Interface model implementation.
@ -13,7 +24,12 @@ class STInterface(ASRInterface):
""" """
def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]): def translate(self,
x,
trans_args,
char_list=None,
rnnlm=None,
ensemble_models=[]):
"""Recognize x for evaluation. """Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D) :param ndarray x: input acouctic feature (B, T, D) or (T, D)
@ -42,6 +58,7 @@ predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel", "transformer": "deepspeech.models.u2_st:U2STModel",
} }
def dynamic_import_st(module): def dynamic_import_st(module):
"""Import ST models dynamically. """Import ST models dynamically.
@ -53,7 +70,6 @@ def dynamic_import_st(module):
""" """
model_class = dynamic_import(module, predefined_st) model_class = dynamic_import(module, predefined_st)
assert issubclass( assert issubclass(model_class,
model_class, STInterface STInterface), f"{module} does not implement STInterface"
), f"{module} does not implement STInterface"
return model_class return model_class

@ -1,2 +1,15 @@
from .u2_st import U2STModel # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel from .u2_st import U2STInferModel
from .u2_st import U2STModel

@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union
import paddle import paddle
from paddle import nn from paddle import nn
from typing import Union
from paddle.nn import functional as F from paddle.nn import functional as F
from typeguard import check_argument_types from typeguard import check_argument_types

@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog() logger = Log(__name__).getlog()
__all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"] __all__ = [
"NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
]
class NoPositionalEncoding(nn.Layer): class NoPositionalEncoding(nn.Layer):
def __init__(self, def __init__(self,

@ -24,9 +24,9 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer from deepspeech.modules.encoder_layer import ConformerEncoderLayer
from deepspeech.modules.encoder_layer import TransformerEncoderLayer from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask from deepspeech.modules.mask import add_optional_chunk_mask
@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos": elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type is "no_pos": elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding pos_enc_class = NoPositionalEncoding
else: else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@ -378,8 +378,7 @@ class TransformerEncoder(BaseEncoder):
self, self,
xs: paddle.Tensor, xs: paddle.Tensor,
masks: paddle.Tensor, masks: paddle.Tensor,
cache=None, cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encode input frame. """Encode input frame.
Args: Args:
@ -397,7 +396,8 @@ class TransformerEncoder(BaseEncoder):
if isinstance(self.embed, Conv2dSubsampling): if isinstance(self.embed, Conv2dSubsampling):
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor #TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0) xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0)
else: else:
xs = self.embed(xs) xs = self.embed(xs)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor #TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor

@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
x, pos_emb = self.pos_enc(x, offset) x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask return x, pos_emb, x_mask
class Conv2dSubsampling(BaseSubsampling): class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class Conv2dSubsampling4(Conv2dSubsampling): class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length).""" """Convolutional 2D subsampling (to 1/4 length)."""

@ -1,15 +1,22 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
import argparse #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy import copy
import json
import os import os
import shutil
import tempfile
import numpy as np
import numpy as np
from . import extension from . import extension
from ..updaters.trainer import Trainer
class PlotAttentionReport(extension.Extension): class PlotAttentionReport(extension.Extension):
@ -49,8 +56,7 @@ class PlotAttentionReport(extension.Extension):
iaxis=0, iaxis=0,
okey="output", okey="output",
oaxis=0, oaxis=0,
subsampling_factor=1, subsampling_factor=1, ):
):
self.att_vis_fn = att_vis_fn self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data) self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)} self.data_dict = {k: v for k, v in copy.deepcopy(data)}
@ -77,44 +83,30 @@ class PlotAttentionReport(extension.Extension):
for i in range(num_encs): for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]): for idx, att_w in enumerate(att_ws[i]):
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % ( filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
self.outdir, self.outdir, uttid_list[idx], i + 1, )
uttid_list[idx],
i + 1,
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w) att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % ( np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
self.outdir, self.outdir, uttid_list[idx], i + 1, )
uttid_list[idx],
i + 1,
)
np.save(np_filename.format(trainer), att_w) np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer)) self._plot_and_save_attention(att_w,
filename.format(trainer))
# han # han
for idx, att_w in enumerate(att_ws[num_encs]): for idx, att_w in enumerate(att_ws[num_encs]):
filename = "%s/%s.ep.{.updater.epoch}.han.png" % ( filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
self.outdir, self.outdir, uttid_list[idx], )
uttid_list[idx],
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w) att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % ( np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
self.outdir, self.outdir, uttid_list[idx], )
uttid_list[idx],
)
np.save(np_filename.format(trainer), att_w) np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention( self._plot_and_save_attention(
att_w, filename.format(trainer), han_mode=True att_w, filename.format(trainer), han_mode=True)
)
else: else:
for idx, att_w in enumerate(att_ws): for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.png" % ( filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
self.outdir, uttid_list[idx], )
uttid_list[idx],
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w) att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, self.outdir, uttid_list[idx], )
uttid_list[idx],
)
np.save(np_filename.format(trainer), att_w) np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer)) self._plot_and_save_attention(att_w, filename.format(trainer))
@ -131,8 +123,7 @@ class PlotAttentionReport(extension.Extension):
logger.add_figure( logger.add_figure(
"%s_att%d" % (uttid_list[idx], i + 1), "%s_att%d" % (uttid_list[idx], i + 1),
plot.gcf(), plot.gcf(),
step, step, )
)
# han # han
for idx, att_w in enumerate(att_ws[num_encs]): for idx, att_w in enumerate(att_ws[num_encs]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w) att_w = self.trim_attention_weight(uttid_list[idx], att_w)
@ -140,8 +131,7 @@ class PlotAttentionReport(extension.Extension):
logger.add_figure( logger.add_figure(
"%s_han" % (uttid_list[idx]), "%s_han" % (uttid_list[idx]),
plot.gcf(), plot.gcf(),
step, step, )
)
else: else:
for idx, att_w in enumerate(att_ws): for idx, att_w in enumerate(att_ws):
att_w = self.trim_attention_weight(uttid_list[idx], att_w) att_w = self.trim_attention_weight(uttid_list[idx], att_w)
@ -298,8 +288,7 @@ class PlotCTCReport(extension.Extension):
iaxis=0, iaxis=0,
okey="output", okey="output",
oaxis=0, oaxis=0,
subsampling_factor=1, subsampling_factor=1, ):
):
self.ctc_vis_fn = ctc_vis_fn self.ctc_vis_fn = ctc_vis_fn
self.data = copy.deepcopy(data) self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)} self.data_dict = {k: v for k, v in copy.deepcopy(data)}
@ -325,29 +314,19 @@ class PlotCTCReport(extension.Extension):
for i in range(num_encs): for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]): for idx, ctc_prob in enumerate(ctc_probs[i]):
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % ( filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
self.outdir, self.outdir, uttid_list[idx], i + 1, )
uttid_list[idx],
i + 1,
)
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % ( np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
self.outdir, self.outdir, uttid_list[idx], i + 1, )
uttid_list[idx],
i + 1,
)
np.save(np_filename.format(trainer), ctc_prob) np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
else: else:
for idx, ctc_prob in enumerate(ctc_probs): for idx, ctc_prob in enumerate(ctc_probs):
filename = "%s/%s.ep.{.updater.epoch}.png" % ( filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
self.outdir, uttid_list[idx], )
uttid_list[idx],
)
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir, self.outdir, uttid_list[idx], )
uttid_list[idx],
)
np.save(np_filename.format(trainer), ctc_prob) np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
@ -363,8 +342,7 @@ class PlotCTCReport(extension.Extension):
logger.add_figure( logger.add_figure(
"%s_ctc%d" % (uttid_list[idx], i + 1), "%s_ctc%d" % (uttid_list[idx], i + 1),
plot.gcf(), plot.gcf(),
step, step, )
)
else: else:
for idx, ctc_prob in enumerate(ctc_probs): for idx, ctc_prob in enumerate(ctc_probs):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
@ -420,8 +398,11 @@ class PlotCTCReport(extension.Extension):
for idx in set(topk_ids.reshape(-1).tolist()): for idx in set(topk_ids.reshape(-1).tolist()):
if idx == 0: if idx == 0:
plt.plot( plt.plot(
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey" times_probs,
) ctc_prob[:, 0],
":",
label="<blank>",
color="grey")
else: else:
plt.plot(times_probs, ctc_prob[:, idx]) plt.plot(times_probs, ctc_prob[:, idx])
plt.xlabel(u"Input [frame]", fontsize=12) plt.xlabel(u"Input [frame]", fontsize=12)

@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .utils import get_trigger
from ..reporter import DictSummary from ..reporter import DictSummary
from .utils import get_trigger
class CompareValueTrigger(): class CompareValueTrigger():
"""Trigger invoked when key value getting bigger or lower than before. """Trigger invoked when key value getting bigger or lower than before.
@ -24,6 +24,7 @@ class CompareValueTrigger():
trigger (tuple(int, str)) : Trigger that decide the comparison interval. trigger (tuple(int, str)) : Trigger that decide the comparison interval.
""" """
def __init__(self, key, compare_fn, trigger=(1, "epoch")): def __init__(self, key, compare_fn, trigger=(1, "epoch")):
self._key = key self._key = key
self._best_value = None self._best_value = None

@ -1,3 +1,16 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger from .interval_trigger import IntervalTrigger

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json import json
import numpy as np import numpy as np
__all__ = ["label_smoothing_dist"] __all__ = ["label_smoothing_dist"]
@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
if lsm_type == "unigram": if lsm_type == "unigram":
assert transcript is not None, ( assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type "transcript is required for %s label smoothing" % lsm_type)
)
labelcount = np.zeros(odim) labelcount = np.zeros(odim)
for k, v in trans_json.items(): for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])

@ -14,9 +14,9 @@
"""This module provides functions to calculate bleu score in different level. """This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
import sacrebleu
import nltk import nltk
import numpy as np import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu', "ErrorCalculator"] __all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
@ -106,11 +106,14 @@ class ErrorCalculator():
# NOTE: padding index (-1) in y_true is used to pad y_hat # NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1 # because y_hats is not padded with -1
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.pad, "") seq_hat_text = seq_hat_text.replace(self.pad, "")
seq_true_text = "".join(seq_true).replace(self.space, " ") seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text) seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text) seqs_true.append(seq_true_text)
bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat) bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
seqs_hat)
return bleu * 100 return bleu * 100

@ -14,11 +14,10 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
from itertools import groupby
import editdistance import editdistance
import numpy as np import numpy as np
import logging
import sys
from itertools import groupby
__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"] __all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
@ -225,9 +224,12 @@ class ErrorCalculator():
:return: :return:
""" """
def __init__( def __init__(self,
self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False char_list,
): sym_space,
sym_blank,
report_cer=False,
report_wer=False):
"""Construct an ErrorCalculator object.""" """Construct an ErrorCalculator object."""
super().__init__() super().__init__()
@ -317,7 +319,9 @@ class ErrorCalculator():
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat # NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ") seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "") seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ") seq_true_text = "".join(seq_true).replace(self.space, " ")

@ -15,7 +15,6 @@ import contextlib
import inspect import inspect
import io import io
import os import os
import re
import subprocess as sp import subprocess as sp
import sys import sys
from pathlib import Path from pathlib import Path
@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
tools_extrs_dir = HERE / 'tools/extras' tools_extrs_dir = HERE / 'tools/extras'
with pushd(tools_extrs_dir): with pushd(tools_extrs_dir):
print(os.getcwd()) print(os.getcwd())
check_call(f"./install_autolog.sh") check_call("./install_autolog.sh")
print("autolog install.") print("autolog install.")
# ctcdecoder # ctcdecoder

@ -4,7 +4,6 @@
# 2018 Xuankai Chang (Shanghai Jiao Tong University) # 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse import argparse
import json
import logging import logging
import sys import sys

Loading…
Cancel
Save