diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 05a37b21b..dd62f537e 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR from deepspeech.training.trainer import Trainer +from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools +from deepspeech.utils import text_grid +from deepspeech.utils import utility from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -278,7 +281,15 @@ class U2Trainer(Trainer): shuffle=False, drop_last=False, collate_fn=SpeechCollator.from_config(config)) - logger.info("Setup train/valid/test Dataloader!") + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self): config = self.config @@ -353,7 +364,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. + # 0: used for training, it's prohibited here. num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. simulate_streaming=False, # simulate streaming inference. Defaults to False. )) @@ -498,6 +509,73 @@ class U2Tester(U2Trainer): except KeyboardInterrupt: sys.exit(-1) + @paddle.no_grad() + def align(self): + if self.config.decoding.batch_size > 1: + logger.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + # xxx.align + assert self.args.result_file and self.args.result_file.endswith( + '.align') + + self.model.eval() + logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") + + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list + with open(self.args.result_file, 'w') as fout: + # one example in batch + for i, batch in enumerate(self.align_loader): + key, feat, feats_length, target, target_length = batch + + # 1. Encoder + encoder_out, encoder_mask = self.model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + + # 2. alignment + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = ctc_utils.forced_align(ctc_probs, target) + logger.info("align ids", key[0], alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + # 3. gen praat + # segment alignment + align_segs = text_grid.segment_alignment(alignment) + logger.info("align tokens", key[0], align_segs) + # IntervalTier, List["start end token\n"] + subsample = utility.get_subsample(self.config) + tierformat = text_grid.align_to_tierformat( + align_segs, subsample, token_dict) + # write tier + align_output_path = os.path.join( + os.path.dirname(self.args.result_file), "align") + tier_path = os.path.join(align_output_path, key[0] + ".tier") + with open(tier_path, 'w') as f: + f.writelines(tierformat) + # write textgrid + textgrid_path = os.path.join(align_output_path, + key[0] + ".TextGrid") + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame + text_grid.generate_textgrid( + maxtime=second_per_example, + intervals=tierformat, + output=textgrid_path) + + def run_align(self): + self.resume_or_scratch() + try: + self.align() + except KeyboardInterrupt: + sys.exit(-1) + def load_inferspec(self): """infer model and input spec. diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 56de32617..5ebba1a98 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,8 +18,8 @@ import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter -from deepspeech.utils import checkpoint from deepspeech.utils import mp_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log __all__ = ["Trainer"] @@ -139,9 +139,9 @@ class Trainer(): "epoch": self.epoch, "lr": self.optimizer.get_lr() }) - checkpoint.save_parameters(self.checkpoint_dir, self.iteration - if tag is None else tag, self.model, - self.optimizer, infos) + self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration + if tag is None else tag, self.model, + self.optimizer, infos) def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output @@ -151,7 +151,7 @@ class Trainer(): resume training. """ scratch = None - infos = checkpoint.load_parameters( + infos = self.checkpoint.load_latest_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, @@ -180,7 +180,7 @@ class Trainer(): from_scratch = self.resume_or_scratch() if from_scratch: # save init model, i.e. 0 epoch - self.save(tag='init') + self.save(tag='init', infos=None) self.lr_scheduler.step(self.iteration) if self.parallel: @@ -263,6 +263,10 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir + self.checkpoint = Checkpoint( + kbest_n=self.config.training.checkpoint.kbest_n, + latest_n=self.config.training.checkpoint.latest_n) + @mp_tools.rank_zero_only def destory(self): """Close visualizer to avoid hanging after training""" diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 8ede6b8fd..8c5d8d605 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import json import os import re +from pathlib import Path from typing import Union import paddle @@ -25,128 +27,260 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["load_parameters", "save_parameters"] - - -def _load_latest_checkpoint(checkpoint_dir: str) -> int: - """Get the iteration number corresponding to the latest saved checkpoint. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - Returns: - int: the latest iteration number. -1 for no checkpoint to load. - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - if not os.path.isfile(checkpoint_record): - return -1 - - # Fetch the latest checkpoint index. - with open(checkpoint_record, "rt") as handle: - latest_checkpoint = handle.readlines()[-1].strip() - iteration = int(latest_checkpoint.split(":")[-1]) - return iteration - - -def _save_record(checkpoint_dir: str, iteration: int): - """Save the iteration number of the latest model to be checkpoint record. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - iteration (int): the latest iteration number. - Returns: - None - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - # Update the latest checkpoint index. - with open(checkpoint_record, "a+") as handle: - handle.write("model_checkpoint_path:{}\n".format(iteration)) - - -def load_parameters(model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): - """Load a specific model checkpoint from disk. - Args: - model (Layer): model to load parameters. - optimizer (Optimizer, optional): optimizer to load states if needed. - Defaults to None. - checkpoint_dir (str, optional): the directory where checkpoint is saved. - checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will - be ignored. Defaults to None. - Returns: - configs (dict): epoch or step, lr and other meta info should be saved. - """ - configs = {} - - if checkpoint_path is not None: - tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - iteration = _load_latest_checkpoint(checkpoint_dir) - if iteration == -1: - return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) - else: - raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" - ) - - rank = dist.get_rank() - - params_path = checkpoint_path + ".pdparams" - model_dict = paddle.load(params_path) - model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) - - optimizer_path = checkpoint_path + ".pdopt" - if optimizer and os.path.isfile(optimizer_path): - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - if os.path.exists(info_path): - with open(info_path, 'r') as fin: - configs = json.load(fin) - return configs - - -@mp_tools.rank_zero_only -def save_parameters(checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): - """Checkpoint the latest trained model parameters. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - tag_or_iteration (int or str): the latest iteration(step or epoch) number. - model (Layer): model to be checkpointed. - optimizer (Optimizer, optional): optimizer to be checkpointed. - Defaults to None. - infos (dict or None): any info you want to save. - Returns: - None - """ - checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - - model_dict = model.state_dict() - params_path = checkpoint_path + ".pdparams" - paddle.save(model_dict, params_path) - logger.info("Saved model to {}".format(params_path)) - - if optimizer: - opt_dict = optimizer.state_dict() +__all__ = ["Checkpoint"] + + +class Checkpoint(object): + def __init__(self, kbest_n: int=5, latest_n: int=1): + self.best_records: Mapping[Path, float] = {} + self.latest_records = [] + self.kbest_n = kbest_n + self.latest_n = latest_n + self._save_all = (kbest_n == -1) + + def add_checkpoint(self, + checkpoint_dir, + tag_or_iteration, + model, + optimizer, + infos, + metric_type="val_loss"): + if (metric_type not in infos.keys()): + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + return + + #save best + if self._should_save_best(infos[metric_type]): + self._save_best_checkpoint_and_update( + infos[metric_type], checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + #save latest + self._save_latest_checkpoint_and_update( + checkpoint_dir, tag_or_iteration, model, optimizer, infos) + + if isinstance(tag_or_iteration, int): + self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) + + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self._load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): + # remove the worst + if self._best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if (worst_record_path not in self.latest_records): + logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) + self._del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def _save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): + # remove the old + if self._latest_full(): + to_del_fn = self.latest_records.pop(0) + if (to_del_fn not in self.best_records.keys()): + logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) + self._del_checkpoint(checkpoint_dir, to_del_fn) + self.latest_records.append(tag_or_iteration) + + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): + os.remove(filename) + logger.info("delete file: {}".format(filename)) + + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_path (str): the saved path of checkpoint. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + Returns: + None + """ + checkpoint_record_latest = os.path.join(checkpoint_dir, + "checkpoint_latest") + checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") + + with open(checkpoint_record_best, "w") as handle: + for i in self.best_records.keys(): + handle.write("model_checkpoint_path:{}\n".format(i)) + with open(checkpoint_record_latest, "w") as handle: + for i in self.latest_records: + handle.write("model_checkpoint_path:{}\n".format(i)) + + def _load_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None, + checkpoint_file=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + checkpoint_file "checkpoint_latest" or "checkpoint_best" + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + configs = {} + + if checkpoint_path is not None: + tag = os.path.basename(checkpoint_path).split(":")[-1] + elif checkpoint_dir is not None and checkpoint_file is not None: + checkpoint_record = os.path.join(checkpoint_dir, checkpoint_file) + iteration = self._load_checkpoint_idx(checkpoint_record) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_dir' and 'checkpoint_file' and 'checkpoint_path' should be specified!" + ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + optimizer_path = checkpoint_path + ".pdopt" - paddle.save(opt_dict, optimizer_path) - logger.info("Saved optimzier state to {}".format(optimizer_path)) + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + @mp_tools.rank_zero_only + def _save_parameters(self, + checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): + """Checkpoint the latest trained model parameters. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + Defaults to None. + infos (dict or None): any info you want to save. + Returns: + None + """ + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + + model_dict = model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) - info_path = re.sub('.pdparams$', '.json', params_path) - infos = {} if infos is None else infos - with open(info_path, 'w') as fout: - data = json.dumps(infos) - fout.write(data) + if optimizer: + opt_dict = optimizer.state_dict() + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) - if isinstance(tag_or_iteration, int): - _save_record(checkpoint_dir, tag_or_iteration) + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 73669fea6..09543d48d 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): + # add non-blank into new_hyp if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) + # skip repeat label prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 return new_hyp -def insert_blank(label: np.ndarray, blank_id: int=0): +def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: """Insert blank token between every two label token. "abcdefg" -> "-a-b-c-d-e-f-g-" Args: - label ([np.ndarray]): label ids, (L). + label ([np.ndarray]): label ids, List[int], (L). blank_id (int, optional): blank id. Defaults to 0. Returns: @@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0): label = np.expand_dims(label, 1) #[L, 1] blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) #[L, 2] - label = label.reshape(-1) #[2L] - label = np.append(label, label[0]) #[2L + 1] + label = label.reshape(-1) #[2L], -l-l-l + label = np.append(label, label[0]) #[2L + 1], -l-l-l- return label def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, - blank_id=0) -> list: + blank_id=0) -> List[int]: """ctc forced alignment. https://distill.pub/2017/ctc/ @@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y (paddle.Tensor): label id sequence tensor, 1d tensor (L) blank_id (int): blank symbol index Returns: - paddle.Tensor: best alignment result, (T). + List[int]: best alignment result, (T). """ - y_insert_blank = insert_blank(y, blank_id) + y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero + # TODO(Hui Zhang): zeros not support paddle.int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 - ) # state path + (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + ) # state path, Tuple((T, 2L+1)) # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): + for t in range(1, ctc_probs.size(0)): # T + for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: candidates = paddle.to_tensor( @@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ - y_insert_blank[s]] + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( + y_insert_blank[s])] state_path[t, s] = prev_state[paddle.argmax(candidates)] - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) + # TODO(Hui Zhang): zeros not support paddle.int16 + state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb diff --git a/deepspeech/utils/text_grid.py b/deepspeech/utils/text_grid.py new file mode 100644 index 000000000..3af58c9ba --- /dev/null +++ b/deepspeech/utils/text_grid.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict +from typing import List +from typing import Text + +import textgrid + + +def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]: + """segment ctc alignment ids by continuous blank and repeat label. + + Args: + alignment (List[int]): ctc alignment id sequence. + e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3] + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + List[List[int]]: token align, segment aligment id sequence. + e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]] + """ + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + align_segs = [] + # get frames level duration for each token + start = 0 + end = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == blank_id: # blank + end += 1 + if end == len(alignment): + align_segs[-1].extend(alignment[start:]) + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[ + end]: # repeat label + end += 1 + align_segs.append(alignment[start:end]) + start = end + return align_segs + + +def align_to_tierformat(align_segs: List[List[int]], + subsample: int, + token_dict: Dict[int, Text], + blank_id=0) -> List[Text]: + """Generate textgrid.Interval format from alignment segmentations. + + Args: + align_segs (List[List[int]]): segmented ctc alignment ids. + subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample + token_dict (Dict[int, Text]): int -> str map. + + Returns: + List[Text]: list of textgrid.Interval text, str(start, end, text). + """ + hop_length = 10 # ms + second_ms = 1000 # ms + frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length + second_per_frame = 1.0 / frame_per_second + + begin = 0 + duration = 0 + tierformat = [] + + for idx, tokens in enumerate(align_segs): + token_len = len(tokens) + token = tokens[-1] + # time duration in second + duration = token_len * subsample * second_per_frame + if idx < len(align_segs) - 1: + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + else: + for i in tokens: + if i != blank_id: + token = i + break + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + begin = begin + duration + + return tierformat + + +def generate_textgrid(maxtime: float, + intervals: List[Text], + output: Text, + name: Text='ali') -> None: + """Create alignment textgrid file. + + Args: + maxtime (float): audio duartion. + intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item. + output (Text): textgrid filepath. + name (Text, optional): tier or layer name. Defaults to 'ali'. + """ + # Download Praat: https://www.fon.hum.uva.nl/praat/ + avg_interval = maxtime / (len(intervals) + 1) + print(f"average second/token: {avg_interval}") + margin = 0.0001 + + tg = textgrid.TextGrid(maxTime=maxtime) + tier = textgrid.IntervalTier(name=name, maxTime=maxtime) + + i = 0 + for dur in intervals: + s, e, text = dur.split() + tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text) + + tg.append(tier) + + tg.write(output) + print("successfully generator textgrid {}.".format(output)) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 64570026b..a0639e065 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float: a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp + + +def get_subsample(config): + """Subsample rate from config. + + Args: + config (yacs.config.CfgNode): yaml config + + Returns: + int: subsample rate. + """ + input_layer = config["model"]["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if input_layer == "conv2d": + return 4 + elif input_layer == "conv2d6": + return 6 + elif input_layer == "conv2d8": + return 8 diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index 1004fde0e..1c97fc607 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 0e5b8699f..3e606788e 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -93,6 +93,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 116c91927..4b1430c58 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -88,6 +88,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/local/align.sh b/examples/aishell/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/aishell/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 4cf09553b..562cfa04d 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -30,10 +30,15 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index b419cbe26..acee94c3e 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ef08daa84..5af689594 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -93,6 +93,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 5ec2ad126..f782a0373 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -86,6 +86,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index cce31b163..955b6108b 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -89,6 +89,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index 8ea494772..8a769dca4 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/local/align.sh b/examples/librispeech/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/librispeech/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index 65194d902..b81e8dcfd 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 6737d1b75..ea433f341 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -43,12 +43,16 @@ model: share_rnn_weights: True training: - n_epoch: 24 + n_epoch: 10 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 1 + checkpoint: + kbest_n: 3 + latest_n: 2 + decoding: batch_size: 128 diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 790066264..606300bdf 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a6..72d368485 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa04..a6f730501 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 250995faa..71cbdde7f 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/local/align.sh b/examples/tiny/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/tiny/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index b148869b7..41f845b05 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi + diff --git a/tools/Makefile b/tools/Makefile index ae68960d5..888a1ac40 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -38,4 +38,4 @@ soxbindings.done: mfa.done: test -d montreal-forced-aligner || wget https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.0.1/montreal-forced-aligner_linux.tar.gz tar xvf montreal-forced-aligner_linux.tar.gz - touch mfa.done \ No newline at end of file + touch mfa.done