Merge branch 'develop' of https://github.com/PaddlePaddle/DeepSpeech into thchs30_MFA

pull/698/head
TianYuan 4 years ago
commit c0ee57d400

@@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model
 from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
 from deepspeech.training.scheduler import WarmupLR
 from deepspeech.training.trainer import Trainer
+from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
 from deepspeech.utils import layer_tools
 from deepspeech.utils import mp_tools
+from deepspeech.utils import text_grid
+from deepspeech.utils import utility
 from deepspeech.utils.log import Log

 logger = Log(__name__).getlog()
@@ -278,7 +281,15 @@ class U2Trainer(Trainer):
             shuffle=False,
             drop_last=False,
             collate_fn=SpeechCollator.from_config(config))
-        logger.info("Setup train/valid/test Dataloader!")
+        # return text token id
+        config.collator.keep_transcription_text = False
+        self.align_loader = DataLoader(
+            test_dataset,
+            batch_size=config.decoding.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=SpeechCollator.from_config(config))
+        logger.info("Setup train/valid/test/align Dataloader!")

     def setup_model(self):
         config = self.config
@@ -353,7 +364,7 @@ class U2Tester(U2Trainer):
         decoding_chunk_size=-1,  # decoding chunk size. Defaults to -1.
         # <0: for decoding, use full chunk.
         # >0: for decoding, use fixed chunk size as set.
         # 0: used for training, it's prohibited here.
         num_decoding_left_chunks=-1,  # number of left chunks for decoding. Defaults to -1.
         simulate_streaming=False,  # simulate streaming inference. Defaults to False.
     ))
@@ -498,6 +509,73 @@ class U2Tester(U2Trainer):
         except KeyboardInterrupt:
             sys.exit(-1)

+    @paddle.no_grad()
+    def align(self):
+        if self.config.decoding.batch_size > 1:
+            logger.fatal('alignment mode must be running with batch_size == 1')
+            sys.exit(1)
+
+        # xxx.align
+        assert self.args.result_file and self.args.result_file.endswith(
+            '.align')
+
+        self.model.eval()
+        logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
+
+        stride_ms = self.align_loader.collate_fn.stride_ms
+        token_dict = self.align_loader.collate_fn.vocab_list
+        with open(self.args.result_file, 'w') as fout:
+            # one example in batch
+            for i, batch in enumerate(self.align_loader):
+                key, feat, feats_length, target, target_length = batch
+
+                # 1. Encoder
+                encoder_out, encoder_mask = self.model._forward_encoder(
+                    feat, feats_length)  # (B, maxlen, encoder_dim)
+                maxlen = encoder_out.size(1)
+                ctc_probs = self.model.ctc.log_softmax(
+                    encoder_out)  # (1, maxlen, vocab_size)
+
+                # 2. alignment
+                ctc_probs = ctc_probs.squeeze(0)
+                target = target.squeeze(0)
+                alignment = ctc_utils.forced_align(ctc_probs, target)
+                logger.info(f"align ids: {key[0]} {alignment}")
+                fout.write('{} {}\n'.format(key[0], alignment))
+
+                # 3. gen praat
+                # segment alignment
+                align_segs = text_grid.segment_alignment(alignment)
+                logger.info(f"align tokens: {key[0]} {align_segs}")
+
+                # IntervalTier, List["start end token\n"]
+                subsample = utility.get_subsample(self.config)
+                tierformat = text_grid.align_to_tierformat(
+                    align_segs, subsample, token_dict)
+
+                # write tier
+                align_output_path = os.path.join(
+                    os.path.dirname(self.args.result_file), "align")
+                tier_path = os.path.join(align_output_path, key[0] + ".tier")
+                with open(tier_path, 'w') as f:
+                    f.writelines(tierformat)
+
+                # write textgrid
+                textgrid_path = os.path.join(align_output_path,
+                                             key[0] + ".TextGrid")
+                second_per_frame = 1. / (1000. /
+                                         stride_ms)  # 25ms window, 10ms stride
+                second_per_example = (
+                    len(alignment) + 1) * subsample * second_per_frame
+                text_grid.generate_textgrid(
+                    maxtime=second_per_example,
+                    intervals=tierformat,
+                    output=textgrid_path)
+
+    def run_align(self):
+        self.resume_or_scratch()
+        try:
+            self.align()
+        except KeyboardInterrupt:
+            sys.exit(-1)
+
     def load_inferspec(self):
         """infer model and input spec.

@@ -18,8 +18,8 @@ import paddle
 from paddle import distributed as dist
 from tensorboardX import SummaryWriter

-from deepspeech.utils import checkpoint
 from deepspeech.utils import mp_tools
+from deepspeech.utils.checkpoint import Checkpoint
 from deepspeech.utils.log import Log

 __all__ = ["Trainer"]
@@ -139,9 +139,9 @@ class Trainer():
             "epoch": self.epoch,
             "lr": self.optimizer.get_lr()
         })
-        checkpoint.save_parameters(self.checkpoint_dir, self.iteration
-                                   if tag is None else tag, self.model,
-                                   self.optimizer, infos)
+        self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
+                                       if tag is None else tag, self.model,
+                                       self.optimizer, infos)
     def resume_or_scratch(self):
         """Resume from latest checkpoint at checkpoints in the output
@@ -151,7 +151,7 @@ class Trainer():
         resume training.
         """
         scratch = None
-        infos = checkpoint.load_parameters(
+        infos = self.checkpoint.load_latest_parameters(
             self.model,
             self.optimizer,
             checkpoint_dir=self.checkpoint_dir,
@@ -180,7 +180,7 @@ class Trainer():
         from_scratch = self.resume_or_scratch()
         if from_scratch:
             # save init model, i.e. 0 epoch
-            self.save(tag='init')
+            self.save(tag='init', infos=None)
         self.lr_scheduler.step(self.iteration)

         if self.parallel:
@@ -263,6 +263,10 @@ class Trainer():
         self.checkpoint_dir = checkpoint_dir

+        self.checkpoint = Checkpoint(
+            kbest_n=self.config.training.checkpoint.kbest_n,
+            latest_n=self.config.training.checkpoint.latest_n)
+
     @mp_tools.rank_zero_only
     def destory(self):
         """Close visualizer to avoid hanging after training"""

@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import json
 import os
 import re
+from pathlib import Path
 from typing import Union

 import paddle
@@ -25,128 +27,260 @@ from deepspeech.utils.log import Log
 logger = Log(__name__).getlog()

-__all__ = ["load_parameters", "save_parameters"]
+__all__ = ["Checkpoint"]


-def _load_latest_checkpoint(checkpoint_dir: str) -> int:
-    """Get the iteration number corresponding to the latest saved checkpoint.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-    Returns:
-        int: the latest iteration number. -1 for no checkpoint to load.
-    """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    if not os.path.isfile(checkpoint_record):
-        return -1
-
-    # Fetch the latest checkpoint index.
-    with open(checkpoint_record, "rt") as handle:
-        latest_checkpoint = handle.readlines()[-1].strip()
-        iteration = int(latest_checkpoint.split(":")[-1])
-    return iteration
-
-
-def _save_record(checkpoint_dir: str, iteration: int):
-    """Save the iteration number of the latest model to be checkpoint record.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        iteration (int): the latest iteration number.
-    Returns:
-        None
-    """
-    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
-    # Update the latest checkpoint index.
-    with open(checkpoint_record, "a+") as handle:
-        handle.write("model_checkpoint_path:{}\n".format(iteration))
-
-
-def load_parameters(model,
-                    optimizer=None,
-                    checkpoint_dir=None,
-                    checkpoint_path=None):
-    """Load a specific model checkpoint from disk.
-    Args:
-        model (Layer): model to load parameters.
-        optimizer (Optimizer, optional): optimizer to load states if needed.
-            Defaults to None.
-        checkpoint_dir (str, optional): the directory where checkpoint is saved.
-        checkpoint_path (str, optional): if specified, load the checkpoint
-            stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
-            be ignored. Defaults to None.
-    Returns:
-        configs (dict): epoch or step, lr and other meta info should be saved.
-    """
-    configs = {}
-
-    if checkpoint_path is not None:
-        tag = os.path.basename(checkpoint_path).split(":")[-1]
-    elif checkpoint_dir is not None:
-        iteration = _load_latest_checkpoint(checkpoint_dir)
-        if iteration == -1:
-            return configs
-        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
-    else:
-        raise ValueError(
-            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
-        )
-
-    rank = dist.get_rank()
-
-    params_path = checkpoint_path + ".pdparams"
-    model_dict = paddle.load(params_path)
-    model.set_state_dict(model_dict)
-    logger.info("Rank {}: loaded model from {}".format(rank, params_path))
-
-    optimizer_path = checkpoint_path + ".pdopt"
-    if optimizer and os.path.isfile(optimizer_path):
-        optimizer_dict = paddle.load(optimizer_path)
-        optimizer.set_state_dict(optimizer_dict)
-        logger.info("Rank {}: loaded optimizer state from {}".format(
-            rank, optimizer_path))
-
-    info_path = re.sub('.pdparams$', '.json', params_path)
-    if os.path.exists(info_path):
-        with open(info_path, 'r') as fin:
-            configs = json.load(fin)
-    return configs
-
-
-@mp_tools.rank_zero_only
-def save_parameters(checkpoint_dir: str,
-                    tag_or_iteration: Union[int, str],
-                    model: paddle.nn.Layer,
-                    optimizer: Optimizer=None,
-                    infos: dict=None):
-    """Checkpoint the latest trained model parameters.
-    Args:
-        checkpoint_dir (str): the directory where checkpoint is saved.
-        tag_or_iteration (int or str): the latest iteration(step or epoch) number.
-        model (Layer): model to be checkpointed.
-        optimizer (Optimizer, optional): optimizer to be checkpointed.
-            Defaults to None.
-        infos (dict or None): any info you want to save.
-    Returns:
-        None
-    """
-    checkpoint_path = os.path.join(checkpoint_dir,
-                                   "{}".format(tag_or_iteration))
-
-    model_dict = model.state_dict()
-    params_path = checkpoint_path + ".pdparams"
-    paddle.save(model_dict, params_path)
-    logger.info("Saved model to {}".format(params_path))
-
-    if optimizer:
-        opt_dict = optimizer.state_dict()
-        optimizer_path = checkpoint_path + ".pdopt"
-        paddle.save(opt_dict, optimizer_path)
-        logger.info("Saved optimzier state to {}".format(optimizer_path))
-
-    info_path = re.sub('.pdparams$', '.json', params_path)
-    infos = {} if infos is None else infos
-    with open(info_path, 'w') as fout:
-        data = json.dumps(infos)
-        fout.write(data)
-
-    if isinstance(tag_or_iteration, int):
-        _save_record(checkpoint_dir, tag_or_iteration)
+class Checkpoint(object):
+    def __init__(self, kbest_n: int=5, latest_n: int=1):
+        self.best_records: Mapping[Path, float] = {}
+        self.latest_records = []
+        self.kbest_n = kbest_n
+        self.latest_n = latest_n
+        self._save_all = (kbest_n == -1)
+
+    def add_checkpoint(self,
+                       checkpoint_dir,
+                       tag_or_iteration,
+                       model,
+                       optimizer,
+                       infos,
+                       metric_type="val_loss"):
+        if (metric_type not in infos.keys()):
+            self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                                  optimizer, infos)
+            return
+
+        # save best
+        if self._should_save_best(infos[metric_type]):
+            self._save_best_checkpoint_and_update(
+                infos[metric_type], checkpoint_dir, tag_or_iteration, model,
+                optimizer, infos)
+        # save latest
+        self._save_latest_checkpoint_and_update(
+            checkpoint_dir, tag_or_iteration, model, optimizer, infos)
+
+        if isinstance(tag_or_iteration, int):
+            self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
+
+    def load_latest_parameters(self,
+                               model,
+                               optimizer=None,
+                               checkpoint_dir=None,
+                               checkpoint_path=None):
+        """Load the latest model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_latest")
+
+    def load_best_parameters(self,
+                             model,
+                             optimizer=None,
+                             checkpoint_dir=None,
+                             checkpoint_path=None):
+        """Load the best model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        return self._load_parameters(model, optimizer, checkpoint_dir,
+                                     checkpoint_path, "checkpoint_best")
+
+    def _should_save_best(self, metric: float) -> bool:
+        if not self._best_full():
+            return True
+
+        # already full
+        worst_record_path = max(self.best_records, key=self.best_records.get)
+        worst_metric = self.best_records[worst_record_path]
+        return metric < worst_metric
+
+    def _best_full(self):
+        return (not self._save_all) and len(self.best_records) == self.kbest_n
+
+    def _latest_full(self):
+        return len(self.latest_records) == self.latest_n
+
+    def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
+                                         tag_or_iteration, model, optimizer,
+                                         infos):
+        # remove the worst
+        if self._best_full():
+            worst_record_path = max(self.best_records,
+                                    key=self.best_records.get)
+            self.best_records.pop(worst_record_path)
+            if (worst_record_path not in self.latest_records):
+                logger.info(
+                    "remove the worst checkpoint: {}".format(worst_record_path))
+                self._del_checkpoint(checkpoint_dir, worst_record_path)
+
+        # add the new one
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)
+        self.best_records[tag_or_iteration] = metric
+
+    def _save_latest_checkpoint_and_update(
+            self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
+        # remove the old
+        if self._latest_full():
+            to_del_fn = self.latest_records.pop(0)
+            if (to_del_fn not in self.best_records.keys()):
+                logger.info(
+                    "remove the latest checkpoint: {}".format(to_del_fn))
+                self._del_checkpoint(checkpoint_dir, to_del_fn)
+        self.latest_records.append(tag_or_iteration)
+
+        self._save_parameters(checkpoint_dir, tag_or_iteration, model,
+                              optimizer, infos)
+
+    def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+        for filename in glob.glob(checkpoint_path + ".*"):
+            os.remove(filename)
+            logger.info("delete file: {}".format(filename))
+
+    def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
+        """Get the iteration number corresponding to the latest saved checkpoint.
+        Args:
+            checkpoint_record (str): the saved path of the checkpoint record file.
+        Returns:
+            int: the latest iteration number. -1 for no checkpoint to load.
+        """
+        if not os.path.isfile(checkpoint_record):
+            return -1
+
+        # Fetch the latest checkpoint index.
+        with open(checkpoint_record, "rt") as handle:
+            latest_checkpoint = handle.readlines()[-1].strip()
+            iteration = int(latest_checkpoint.split(":")[-1])
+        return iteration
+
+    def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
+        """Save the iteration number of the latest model to the checkpoint records.
+        Args:
+            checkpoint_dir (str): the directory where checkpoint is saved.
+            iteration (int): the latest iteration number.
+        Returns:
+            None
+        """
+        checkpoint_record_latest = os.path.join(checkpoint_dir,
+                                                "checkpoint_latest")
+        checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
+
+        with open(checkpoint_record_best, "w") as handle:
+            for i in self.best_records.keys():
+                handle.write("model_checkpoint_path:{}\n".format(i))
+        with open(checkpoint_record_latest, "w") as handle:
+            for i in self.latest_records:
+                handle.write("model_checkpoint_path:{}\n".format(i))
+
+    def _load_parameters(self,
+                         model,
+                         optimizer=None,
+                         checkpoint_dir=None,
+                         checkpoint_path=None,
+                         checkpoint_file=None):
+        """Load a model checkpoint from disk.
+        Args:
+            model (Layer): model to load parameters.
+            optimizer (Optimizer, optional): optimizer to load states if needed.
+                Defaults to None.
+            checkpoint_dir (str, optional): the directory where checkpoint is saved.
+            checkpoint_path (str, optional): if specified, load the checkpoint
+                stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
+                be ignored. Defaults to None.
+            checkpoint_file (str): "checkpoint_latest" or "checkpoint_best".
+        Returns:
+            configs (dict): epoch or step, lr and other meta info should be saved.
+        """
+        configs = {}
+
+        if checkpoint_path is not None:
+            tag = os.path.basename(checkpoint_path).split(":")[-1]
+        elif checkpoint_dir is not None and checkpoint_file is not None:
+            checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file)
+            iteration = self._load_checkpoint_idx(checkpoint_record)
+            if iteration == -1:
+                return configs
+            checkpoint_path = os.path.join(checkpoint_dir,
+                                           "{}".format(iteration))
+        else:
+            raise ValueError(
+                "At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!"
+            )
+
+        rank = dist.get_rank()
+
+        params_path = checkpoint_path + ".pdparams"
+        model_dict = paddle.load(params_path)
+        model.set_state_dict(model_dict)
+        logger.info("Rank {}: loaded model from {}".format(rank, params_path))
+
+        optimizer_path = checkpoint_path + ".pdopt"
+        if optimizer and os.path.isfile(optimizer_path):
+            optimizer_dict = paddle.load(optimizer_path)
+            optimizer.set_state_dict(optimizer_dict)
+            logger.info("Rank {}: loaded optimizer state from {}".format(
+                rank, optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        if os.path.exists(info_path):
+            with open(info_path, 'r') as fin:
+                configs = json.load(fin)
+        return configs
+
+    @mp_tools.rank_zero_only
+    def _save_parameters(self,
+                         checkpoint_dir: str,
+                         tag_or_iteration: Union[int, str],
+                         model: paddle.nn.Layer,
+                         optimizer: Optimizer=None,
+                         infos: dict=None):
+        """Checkpoint the latest trained model parameters.
+        Args:
+            checkpoint_dir (str): the directory where checkpoint is saved.
+            tag_or_iteration (int or str): the latest iteration(step or epoch) number.
+            model (Layer): model to be checkpointed.
+            optimizer (Optimizer, optional): optimizer to be checkpointed.
+                Defaults to None.
+            infos (dict or None): any info you want to save.
+        Returns:
+            None
+        """
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "{}".format(tag_or_iteration))
+
+        model_dict = model.state_dict()
+        params_path = checkpoint_path + ".pdparams"
+        paddle.save(model_dict, params_path)
+        logger.info("Saved model to {}".format(params_path))
+
+        if optimizer:
+            opt_dict = optimizer.state_dict()
+            optimizer_path = checkpoint_path + ".pdopt"
+            paddle.save(opt_dict, optimizer_path)
+            logger.info("Saved optimzier state to {}".format(optimizer_path))
+
+        info_path = re.sub('.pdparams$', '.json', params_path)
+        infos = {} if infos is None else infos
+        with open(info_path, 'w') as fout:
+            data = json.dumps(infos)
+            fout.write(data)
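Note: the k-best/latest bookkeeping above is easier to see in isolation. A standalone sketch of the record-keeping logic (simplified names, no file I/O, and ignoring the kbest_n == -1 save-all case): a dict of tag -> metric capped at kbest_n plus a FIFO list capped at latest_n, where a checkpoint's files are only deleted once its tag appears in neither structure.

    # standalone sketch of the Checkpoint record-keeping (no file I/O)
    class CheckpointRecords:
        def __init__(self, kbest_n=5, latest_n=1):
            self.best = {}      # tag -> metric (lower is better, e.g. val_loss)
            self.latest = []    # FIFO of tags
            self.kbest_n = kbest_n
            self.latest_n = latest_n

        def add(self, tag, metric):
            deleted = []
            # best set: evict the current worst if full and the new one is better
            if len(self.best) == self.kbest_n:
                worst = max(self.best, key=self.best.get)
                if metric < self.best[worst]:
                    self.best.pop(worst)
                    if worst not in self.latest:
                        deleted.append(worst)
            if len(self.best) < self.kbest_n:
                self.best[tag] = metric
            # latest set: FIFO eviction, but never delete a tag still in best
            if len(self.latest) == self.latest_n:
                old = self.latest.pop(0)
                if old not in self.best:
                    deleted.append(old)
            self.latest.append(tag)
            return deleted  # tags whose files would be removed on disk

    records = CheckpointRecords(kbest_n=2, latest_n=1)
    for tag, loss in [(1, 0.9), (2, 0.7), (3, 0.8), (4, 0.6)]:
        print(tag, records.add(tag, loss), records.best, records.latest)
    # tag 3 evicts tag 1 from best (0.8 < 0.9) and deletes it;
    # tag 4 evicts tag 3 from best, but tag 3 survives until it also
    # leaves the latest list.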

@@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
     new_hyp: List[int] = []
     cur = 0
     while cur < len(hyp):
+        # add non-blank into new_hyp
         if hyp[cur] != blank_id:
             new_hyp.append(hyp[cur])
+        # skip repeat label
         prev = cur
         while cur < len(hyp) and hyp[cur] == hyp[prev]:
             cur += 1
     return new_hyp


-def insert_blank(label: np.ndarray, blank_id: int=0):
+def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
     """Insert blank token between every two label token.
     "abcdefg" -> "-a-b-c-d-e-f-g-"
     Args:
-        label ([np.ndarray]): label ids, (L).
+        label ([np.ndarray]): label ids, List[int], (L).
         blank_id (int, optional): blank id. Defaults to 0.
     Returns:
@@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
     label = np.expand_dims(label, 1)  #[L, 1]
     blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
     label = np.concatenate([blanks, label], axis=1)  #[L, 2]
-    label = label.reshape(-1)  #[2L]
-    label = np.append(label, label[0])  #[2L + 1]
+    label = label.reshape(-1)  #[2L], -l-l-l
+    label = np.append(label, label[0])  #[2L + 1], -l-l-l-
     return label


 def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
-                 blank_id=0) -> list:
+                 blank_id=0) -> List[int]:
     """ctc forced alignment.

     https://distill.pub/2017/ctc/
@@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
         y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
         blank_id (int): blank symbol index
     Returns:
-        paddle.Tensor: best alignment result, (T).
+        List[int]: best alignment result, (T).
     """
-    y_insert_blank = insert_blank(y, blank_id)
+    y_insert_blank = insert_blank(y, blank_id)  #(2L+1)

     log_alpha = paddle.zeros(
         (ctc_probs.size(0), len(y_insert_blank)))  #(T, 2L+1)
     log_alpha = log_alpha - float('inf')  # log of zero
+    # TODO(Hui Zhang): zeros not support paddle.int16
     state_path = (paddle.zeros(
-        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
-    )  # state path
+        (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
+    )  # state path, Tuple((T, 2L+1))

     # init start state
-    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]  # Sb
-    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]  # Snb
+    # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
+    log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])]  # State-b, Sb
+    log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])]  # State-nb, Snb

-    for t in range(1, ctc_probs.size(0)):
-        for s in range(len(y_insert_blank)):
+    for t in range(1, ctc_probs.size(0)):  # T
+        for s in range(len(y_insert_blank)):  # 2L+1
             if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                     s] == y_insert_blank[s - 2]:
                 candidates = paddle.to_tensor(
@@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
                     log_alpha[t - 1, s - 2],
                 ])
                 prev_state = [s, s - 1, s - 2]
-            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
-                y_insert_blank[s]]
+            # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
+            log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
+                y_insert_blank[s])]
             state_path[t, s] = prev_state[paddle.argmax(candidates)]

-    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
+    # TODO(Hui Zhang): zeros not support paddle.int16
+    state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)

     candidates = paddle.to_tensor([
         log_alpha[-1, len(y_insert_blank) - 1],  # Sb
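Note: for intuition, here is a numpy-only sketch of the same Viterbi-style forced alignment over the blank-augmented label sequence. It follows the recurrence above (state s is reachable from s, s-1, and, for non-blank non-repeat states, s-2), but it is an illustration, not the Paddle implementation:

    import numpy as np

    def insert_blank(label, blank_id=0):
        # "abc" -> "-a-b-c-": a blank before, between and after every token
        out = [blank_id]
        for l in label:
            out.extend([l, blank_id])
        return np.array(out)

    def forced_align(log_probs, label, blank_id=0):
        # log_probs: (T, V) log posteriors; label: (L,) token ids
        y = insert_blank(label, blank_id)
        T, S = log_probs.shape[0], len(y)
        log_alpha = np.full((T, S), -np.inf)
        backptr = np.zeros((T, S), dtype=np.int64)
        log_alpha[0, 0] = log_probs[0, y[0]]
        log_alpha[0, 1] = log_probs[0, y[1]]
        for t in range(1, T):
            for s in range(S):
                cand = [log_alpha[t - 1, s]]
                if s >= 1:
                    cand.append(log_alpha[t - 1, s - 1])
                # the skip from s-2 is illegal for blanks and repeated tokens
                if s >= 2 and y[s] != blank_id and y[s] != y[s - 2]:
                    cand.append(log_alpha[t - 1, s - 2])
                best = int(np.argmax(cand))
                log_alpha[t, s] = cand[best] + log_probs[t, y[s]]
                backptr[t, s] = s - best
        # backtrack from the better of the two final states
        s = S - 1 if log_alpha[-1, S - 1] > log_alpha[-1, S - 2] else S - 2
        path = [s]
        for t in range(T - 1, 0, -1):
            s = backptr[t, s]
            path.append(s)
        return [int(y[st]) for st in reversed(path)]

    rng = np.random.default_rng(0)
    probs = rng.dirichlet(np.ones(5), size=8)      # (T=8, V=5) rows sum to 1
    print(forced_align(np.log(probs), [1, 2, 3]))  # e.g. [0, 1, 2, 2, 0, 3, 3, 0]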

@@ -0,0 +1,127 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid


def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: token align, segmented alignment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
    # convert alignment to a praat format; praat is a "doing phonetics
    # by computer" tool that helps with analyzing the alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
    return align_segs


def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): encoder subsample rate (feature frames per encoder frame); features use 25ms frame_length and 10ms hop_length.
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
    return tierformat


def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duration.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))

@@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
     a_max = max(args)
     lsp = math.log(sum(math.exp(a - a_max) for a in args))
     return a_max + lsp
+
+
+def get_subsample(config):
+    """Subsample rate from config.
+
+    Args:
+        config (yacs.config.CfgNode): yaml config
+
+    Returns:
+        int: subsample rate.
+    """
+    input_layer = config["model"]["encoder_conf"]["input_layer"]
+    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
+    if input_layer == "conv2d":
+        return 4
+    elif input_layer == "conv2d6":
+        return 6
+    elif input_layer == "conv2d8":
+        return 8
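Note: get_subsample only performs dict-style lookups, so a plain nested dict can stand in for the yacs CfgNode in a quick check:

    from deepspeech.utils.utility import get_subsample

    config = {"model": {"encoder_conf": {"input_layer": "conv2d"}}}
    print(get_subsample(config))  # 4 feature frames per encoder frame (40 ms)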

@@ -48,6 +48,9 @@ training:
   weight_decay: 1e-06
   global_grad_clip: 3.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:
   batch_size: 128

@@ -93,6 +93,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -88,6 +88,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@@ -30,10 +30,15 @@ fi

 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     # test ckpt avg_n
-    CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+    CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # ctc alignment of test data
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # export ckpt avg_n
     CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi

@@ -48,6 +48,9 @@ training:
   weight_decay: 1e-06
   global_grad_clip: 5.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:
   batch_size: 128

@@ -93,6 +93,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -86,6 +86,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -89,6 +89,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 100
+  checkpoint:
+    kbest_n: 50
+    latest_n: 5

 decoding:

@@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # ctc alignment of test data
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # export ckpt avg_n
     CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi

@@ -43,12 +43,16 @@ model:
   share_rnn_weights: True

 training:
-  n_epoch: 24
+  n_epoch: 10
   lr: 1e-5
   lr_decay: 1.0
   weight_decay: 1e-06
   global_grad_clip: 5.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 3
+    latest_n: 2

 decoding:
   batch_size: 128

@@ -91,6 +91,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -87,6 +87,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -84,6 +84,9 @@ training:
   warmup_steps: 25000
   lr_decay: 1.0
   log_interval: 1
+  checkpoint:
+    kbest_n: 10
+    latest_n: 1

 decoding:

@@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${type}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 fi

 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # ctc alignment of test data
+    CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
     # export ckpt avg_n
     CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
 fi

@@ -38,4 +38,4 @@ soxbindings.done:
 mfa.done:
 	test -d montreal-forced-aligner || wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz
 	tar xvf montreal-forced-aligner_linux.tar.gz
 	touch mfa.done
