diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 308569cd7..dd62f537e 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog from deepspeech.training.scheduler import WarmupLR from deepspeech.training.trainer import Trainer +from deepspeech.utils import ctc_utils from deepspeech.utils import error_rate from deepspeech.utils import layer_tools from deepspeech.utils import mp_tools +from deepspeech.utils import text_grid +from deepspeech.utils import utility from deepspeech.utils.log import Log logger = Log(__name__).getlog() @@ -278,7 +281,15 @@ class U2Trainer(Trainer): shuffle=False, drop_last=False, collate_fn=SpeechCollator.from_config(config)) - logger.info("Setup train/valid/test Dataloader!") + # return text token id + config.collator.keep_transcription_text = False + self.align_loader = DataLoader( + test_dataset, + batch_size=config.decoding.batch_size, + shuffle=False, + drop_last=False, + collate_fn=SpeechCollator.from_config(config)) + logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self): config = self.config @@ -353,7 +364,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=-1, # decoding chunk size. Defaults to -1. # <0: for decoding, use full chunk. # >0: for decoding, use fixed chunk size as set. - # 0: used for training, it's prohibited here. + # 0: used for training, it's prohibited here. num_decoding_left_chunks=-1, # number of left chunks for decoding. Defaults to -1. simulate_streaming=False, # simulate streaming inference. Defaults to False. )) @@ -498,6 +509,73 @@ class U2Tester(U2Trainer): except KeyboardInterrupt: sys.exit(-1) + @paddle.no_grad() + def align(self): + if self.config.decoding.batch_size > 1: + logger.fatal('alignment mode must be running with batch_size == 1') + sys.exit(1) + + # xxx.align + assert self.args.result_file and self.args.result_file.endswith( + '.align') + + self.model.eval() + logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}") + + stride_ms = self.align_loader.collate_fn.stride_ms + token_dict = self.align_loader.collate_fn.vocab_list + with open(self.args.result_file, 'w') as fout: + # one example in batch + for i, batch in enumerate(self.align_loader): + key, feat, feats_length, target, target_length = batch + + # 1. Encoder + encoder_out, encoder_mask = self.model._forward_encoder( + feat, feats_length) # (B, maxlen, encoder_dim) + maxlen = encoder_out.size(1) + ctc_probs = self.model.ctc.log_softmax( + encoder_out) # (1, maxlen, vocab_size) + + # 2. alignment + ctc_probs = ctc_probs.squeeze(0) + target = target.squeeze(0) + alignment = ctc_utils.forced_align(ctc_probs, target) + logger.info("align ids", key[0], alignment) + fout.write('{} {}\n'.format(key[0], alignment)) + + # 3. gen praat + # segment alignment + align_segs = text_grid.segment_alignment(alignment) + logger.info("align tokens", key[0], align_segs) + # IntervalTier, List["start end token\n"] + subsample = utility.get_subsample(self.config) + tierformat = text_grid.align_to_tierformat( + align_segs, subsample, token_dict) + # write tier + align_output_path = os.path.join( + os.path.dirname(self.args.result_file), "align") + tier_path = os.path.join(align_output_path, key[0] + ".tier") + with open(tier_path, 'w') as f: + f.writelines(tierformat) + # write textgrid + textgrid_path = os.path.join(align_output_path, + key[0] + ".TextGrid") + second_per_frame = 1. / (1000. / + stride_ms) # 25ms window, 10ms stride + second_per_example = ( + len(alignment) + 1) * subsample * second_per_frame + text_grid.generate_textgrid( + maxtime=second_per_example, + intervals=tierformat, + output=textgrid_path) + + def run_align(self): + self.resume_or_scratch() + try: + self.align() + except KeyboardInterrupt: + sys.exit(-1) + def load_inferspec(self): """infer model and input spec. @@ -511,10 +589,9 @@ class U2Tester(U2Trainer): self.args.checkpoint_path) feat_dim = self.test_loader.collate_fn.feature_size input_spec = [ - paddle.static.InputSpec( - shape=[None, feat_dim, None], - dtype='float32'), # audio, [B,D,T] - paddle.static.InputSpec(shape=[None], + paddle.static.InputSpec(shape=[1, None, feat_dim], + dtype='float32'), # audio, [B,T,D] + paddle.static.InputSpec(shape=[1], dtype='int64'), # audio_length, [B] ] return infer_model, input_spec diff --git a/deepspeech/io/collator.py b/deepspeech/io/collator.py index 7510dee04..80b62e886 100644 --- a/deepspeech/io/collator.py +++ b/deepspeech/io/collator.py @@ -156,8 +156,8 @@ class SpeechCollator(): random_seed (int, optional): for random generator. Defaults to 0. keep_transcription_text (bool, optional): True, when not in training mode, will not do tokenizer; Defaults to False. if ``keep_transcription_text`` is False, text is token ids else is raw string. - - Do augmentations + + Do augmentations Padding audio features with zeros to make them have the same shape (or a user-defined shape) within one batch. """ diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index 23ae3423d..6b266bdb4 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -599,26 +599,26 @@ class U2BaseModel(nn.Module): best_index = i return hyps[best_index][0] - @jit.export + #@jit.export def subsampling_rate(self) -> int: """ Export interface for c++ call, return subsampling_rate of the model """ return self.encoder.embed.subsampling_rate - @jit.export + #@jit.export def right_context(self) -> int: """ Export interface for c++ call, return right_context of the model """ return self.encoder.embed.right_context - @jit.export + #@jit.export def sos_symbol(self) -> int: """ Export interface for c++ call, return sos symbol id of the model """ return self.sos - @jit.export + #@jit.export def eos_symbol(self) -> int: """ Export interface for c++ call, return eos symbol id of the model """ @@ -654,12 +654,14 @@ class U2BaseModel(nn.Module): xs, offset, required_cache_size, subsampling_cache, elayers_output_cache, conformer_cnn_cache) - @jit.export + # @jit.export([ + # paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D] + # ]) def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor: """ Export interface for c++ call, apply linear transform and log softmax before ctc Args: - xs (paddle.Tensor): encoder output + xs (paddle.Tensor): encoder output, (B, T, D) Returns: paddle.Tensor: activation before ctc """ @@ -894,7 +896,7 @@ class U2Model(U2BaseModel): model = cls.from_config(config) if checkpoint_path: - infos = checkpoint.load_parameters( + infos = checkpoint.Checkpoint().load_parameters( model, checkpoint_path=checkpoint_path) logger.info(f"checkpoint info: {infos}") layer_tools.summary(model) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 56de32617..5ebba1a98 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -18,8 +18,8 @@ import paddle from paddle import distributed as dist from tensorboardX import SummaryWriter -from deepspeech.utils import checkpoint from deepspeech.utils import mp_tools +from deepspeech.utils.checkpoint import Checkpoint from deepspeech.utils.log import Log __all__ = ["Trainer"] @@ -139,9 +139,9 @@ class Trainer(): "epoch": self.epoch, "lr": self.optimizer.get_lr() }) - checkpoint.save_parameters(self.checkpoint_dir, self.iteration - if tag is None else tag, self.model, - self.optimizer, infos) + self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration + if tag is None else tag, self.model, + self.optimizer, infos) def resume_or_scratch(self): """Resume from latest checkpoint at checkpoints in the output @@ -151,7 +151,7 @@ class Trainer(): resume training. """ scratch = None - infos = checkpoint.load_parameters( + infos = self.checkpoint.load_latest_parameters( self.model, self.optimizer, checkpoint_dir=self.checkpoint_dir, @@ -180,7 +180,7 @@ class Trainer(): from_scratch = self.resume_or_scratch() if from_scratch: # save init model, i.e. 0 epoch - self.save(tag='init') + self.save(tag='init', infos=None) self.lr_scheduler.step(self.iteration) if self.parallel: @@ -263,6 +263,10 @@ class Trainer(): self.checkpoint_dir = checkpoint_dir + self.checkpoint = Checkpoint( + kbest_n=self.config.training.checkpoint.kbest_n, + latest_n=self.config.training.checkpoint.latest_n) + @mp_tools.rank_zero_only def destory(self): """Close visualizer to avoid hanging after training""" diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index 8ede6b8fd..a59f8be79 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import json import os import re +from pathlib import Path +from typing import Text from typing import Union import paddle @@ -25,128 +28,271 @@ from deepspeech.utils.log import Log logger = Log(__name__).getlog() -__all__ = ["load_parameters", "save_parameters"] - - -def _load_latest_checkpoint(checkpoint_dir: str) -> int: - """Get the iteration number corresponding to the latest saved checkpoint. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - Returns: - int: the latest iteration number. -1 for no checkpoint to load. - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - if not os.path.isfile(checkpoint_record): - return -1 - - # Fetch the latest checkpoint index. - with open(checkpoint_record, "rt") as handle: - latest_checkpoint = handle.readlines()[-1].strip() - iteration = int(latest_checkpoint.split(":")[-1]) - return iteration - - -def _save_record(checkpoint_dir: str, iteration: int): - """Save the iteration number of the latest model to be checkpoint record. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - iteration (int): the latest iteration number. - Returns: - None - """ - checkpoint_record = os.path.join(checkpoint_dir, "checkpoint") - # Update the latest checkpoint index. - with open(checkpoint_record, "a+") as handle: - handle.write("model_checkpoint_path:{}\n".format(iteration)) - - -def load_parameters(model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None): - """Load a specific model checkpoint from disk. - Args: - model (Layer): model to load parameters. - optimizer (Optimizer, optional): optimizer to load states if needed. - Defaults to None. - checkpoint_dir (str, optional): the directory where checkpoint is saved. - checkpoint_path (str, optional): if specified, load the checkpoint - stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will - be ignored. Defaults to None. - Returns: - configs (dict): epoch or step, lr and other meta info should be saved. - """ - configs = {} - - if checkpoint_path is not None: - tag = os.path.basename(checkpoint_path).split(":")[-1] - elif checkpoint_dir is not None: - iteration = _load_latest_checkpoint(checkpoint_dir) - if iteration == -1: - return configs - checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration)) - else: - raise ValueError( - "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!" - ) - - rank = dist.get_rank() - - params_path = checkpoint_path + ".pdparams" - model_dict = paddle.load(params_path) - model.set_state_dict(model_dict) - logger.info("Rank {}: loaded model from {}".format(rank, params_path)) - - optimizer_path = checkpoint_path + ".pdopt" - if optimizer and os.path.isfile(optimizer_path): - optimizer_dict = paddle.load(optimizer_path) - optimizer.set_state_dict(optimizer_dict) - logger.info("Rank {}: loaded optimizer state from {}".format( - rank, optimizer_path)) - - info_path = re.sub('.pdparams$', '.json', params_path) - if os.path.exists(info_path): - with open(info_path, 'r') as fin: - configs = json.load(fin) - return configs - - -@mp_tools.rank_zero_only -def save_parameters(checkpoint_dir: str, - tag_or_iteration: Union[int, str], - model: paddle.nn.Layer, - optimizer: Optimizer=None, - infos: dict=None): - """Checkpoint the latest trained model parameters. - Args: - checkpoint_dir (str): the directory where checkpoint is saved. - tag_or_iteration (int or str): the latest iteration(step or epoch) number. - model (Layer): model to be checkpointed. - optimizer (Optimizer, optional): optimizer to be checkpointed. - Defaults to None. - infos (dict or None): any info you want to save. - Returns: - None - """ - checkpoint_path = os.path.join(checkpoint_dir, - "{}".format(tag_or_iteration)) - - model_dict = model.state_dict() - params_path = checkpoint_path + ".pdparams" - paddle.save(model_dict, params_path) - logger.info("Saved model to {}".format(params_path)) - - if optimizer: - opt_dict = optimizer.state_dict() +__all__ = ["Checkpoint"] + + +class Checkpoint(): + def __init__(self, kbest_n: int=5, latest_n: int=1): + self.best_records: Mapping[Path, float] = {} + self.latest_records = [] + self.kbest_n = kbest_n + self.latest_n = latest_n + self._save_all = (kbest_n == -1) + + def add_checkpoint(self, + checkpoint_dir, + tag_or_iteration: Union[int, Text], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None, + metric_type="val_loss"): + """Save checkpoint in best_n and latest_n. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number or tag. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + infos (dict or None)): any info you want to save. + metric_type (str, optional): metric type. Defaults to "val_loss". + """ + if (metric_type not in infos.keys()): + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + return + + #save best + if self._should_save_best(infos[metric_type]): + self._save_best_checkpoint_and_update( + infos[metric_type], checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + #save latest + self._save_latest_checkpoint_and_update( + checkpoint_dir, tag_or_iteration, model, optimizer, infos) + + if isinstance(tag_or_iteration, int): + self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) + + def load_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None, + record_file="checkpoint_latest"): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + record_file "checkpoint_latest" or "checkpoint_best" + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + configs = {} + + if checkpoint_path is not None: + pass + elif checkpoint_dir is not None and record_file is not None: + # load checkpint from record file + checkpoint_record = os.path.join(checkpoint_dir, record_file) + iteration = self._load_checkpoint_idx(checkpoint_record) + if iteration == -1: + return configs + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(iteration)) + else: + raise ValueError( + "At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!" + ) + + rank = dist.get_rank() + + params_path = checkpoint_path + ".pdparams" + model_dict = paddle.load(params_path) + model.set_state_dict(model_dict) + logger.info("Rank {}: loaded model from {}".format(rank, params_path)) + optimizer_path = checkpoint_path + ".pdopt" - paddle.save(opt_dict, optimizer_path) - logger.info("Saved optimzier state to {}".format(optimizer_path)) + if optimizer and os.path.isfile(optimizer_path): + optimizer_dict = paddle.load(optimizer_path) + optimizer.set_state_dict(optimizer_dict) + logger.info("Rank {}: loaded optimizer state from {}".format( + rank, optimizer_path)) + + info_path = re.sub('.pdparams$', '.json', params_path) + if os.path.exists(info_path): + with open(info_path, 'r') as fin: + configs = json.load(fin) + return configs + + def load_latest_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_latest") + + def load_best_parameters(self, + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None): + """Load a last model checkpoint from disk. + Args: + model (Layer): model to load parameters. + optimizer (Optimizer, optional): optimizer to load states if needed. + Defaults to None. + checkpoint_dir (str, optional): the directory where checkpoint is saved. + checkpoint_path (str, optional): if specified, load the checkpoint + stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will + be ignored. Defaults to None. + Returns: + configs (dict): epoch or step, lr and other meta info should be saved. + """ + return self.load_parameters(model, optimizer, checkpoint_dir, + checkpoint_path, "checkpoint_best") + + def _should_save_best(self, metric: float) -> bool: + if not self._best_full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + # worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0] + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def _best_full(self): + return (not self._save_all) and len(self.best_records) == self.kbest_n + + def _latest_full(self): + return len(self.latest_records) == self.latest_n + + def _save_best_checkpoint_and_update(self, metric, checkpoint_dir, + tag_or_iteration, model, optimizer, + infos): + # remove the worst + if self._best_full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + if (worst_record_path not in self.latest_records): + logger.info( + "remove the worst checkpoint: {}".format(worst_record_path)) + self._del_checkpoint(checkpoint_dir, worst_record_path) + + # add the new one + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + self.best_records[tag_or_iteration] = metric + + def _save_latest_checkpoint_and_update( + self, checkpoint_dir, tag_or_iteration, model, optimizer, infos): + # remove the old + if self._latest_full(): + to_del_fn = self.latest_records.pop(0) + if (to_del_fn not in self.best_records.keys()): + logger.info( + "remove the latest checkpoint: {}".format(to_del_fn)) + self._del_checkpoint(checkpoint_dir, to_del_fn) + self.latest_records.append(tag_or_iteration) + + self._save_parameters(checkpoint_dir, tag_or_iteration, model, + optimizer, infos) + + def _del_checkpoint(self, checkpoint_dir, tag_or_iteration): + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + for filename in glob.glob(checkpoint_path + ".*"): + os.remove(filename) + logger.info("delete file: {}".format(filename)) + + def _load_checkpoint_idx(self, checkpoint_record: str) -> int: + """Get the iteration number corresponding to the latest saved checkpoint. + Args: + checkpoint_path (str): the saved path of checkpoint. + Returns: + int: the latest iteration number. -1 for no checkpoint to load. + """ + if not os.path.isfile(checkpoint_record): + return -1 + + # Fetch the latest checkpoint index. + with open(checkpoint_record, "rt") as handle: + latest_checkpoint = handle.readlines()[-1].strip() + iteration = int(latest_checkpoint.split(":")[-1]) + return iteration + + def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int): + """Save the iteration number of the latest model to be checkpoint record. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + Returns: + None + """ + checkpoint_record_latest = os.path.join(checkpoint_dir, + "checkpoint_latest") + checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best") + + with open(checkpoint_record_best, "w") as handle: + for i in self.best_records.keys(): + handle.write("model_checkpoint_path:{}\n".format(i)) + with open(checkpoint_record_latest, "w") as handle: + for i in self.latest_records: + handle.write("model_checkpoint_path:{}\n".format(i)) + + @mp_tools.rank_zero_only + def _save_parameters(self, + checkpoint_dir: str, + tag_or_iteration: Union[int, str], + model: paddle.nn.Layer, + optimizer: Optimizer=None, + infos: dict=None): + """Checkpoint the latest trained model parameters. + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + tag_or_iteration (int or str): the latest iteration(step or epoch) number. + model (Layer): model to be checkpointed. + optimizer (Optimizer, optional): optimizer to be checkpointed. + Defaults to None. + infos (dict or None): any info you want to save. + Returns: + None + """ + checkpoint_path = os.path.join(checkpoint_dir, + "{}".format(tag_or_iteration)) + + model_dict = model.state_dict() + params_path = checkpoint_path + ".pdparams" + paddle.save(model_dict, params_path) + logger.info("Saved model to {}".format(params_path)) - info_path = re.sub('.pdparams$', '.json', params_path) - infos = {} if infos is None else infos - with open(info_path, 'w') as fout: - data = json.dumps(infos) - fout.write(data) + if optimizer: + opt_dict = optimizer.state_dict() + optimizer_path = checkpoint_path + ".pdopt" + paddle.save(opt_dict, optimizer_path) + logger.info("Saved optimzier state to {}".format(optimizer_path)) - if isinstance(tag_or_iteration, int): - _save_record(checkpoint_dir, tag_or_iteration) + info_path = re.sub('.pdparams$', '.json', params_path) + infos = {} if infos is None else infos + with open(info_path, 'w') as fout: + data = json.dumps(infos) + fout.write(data) diff --git a/deepspeech/utils/ctc_utils.py b/deepspeech/utils/ctc_utils.py index 73669fea6..09543d48d 100644 --- a/deepspeech/utils/ctc_utils.py +++ b/deepspeech/utils/ctc_utils.py @@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]: new_hyp: List[int] = [] cur = 0 while cur < len(hyp): + # add non-blank into new_hyp if hyp[cur] != blank_id: new_hyp.append(hyp[cur]) + # skip repeat label prev = cur while cur < len(hyp) and hyp[cur] == hyp[prev]: cur += 1 return new_hyp -def insert_blank(label: np.ndarray, blank_id: int=0): +def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray: """Insert blank token between every two label token. "abcdefg" -> "-a-b-c-d-e-f-g-" Args: - label ([np.ndarray]): label ids, (L). + label ([np.ndarray]): label ids, List[int], (L). blank_id (int, optional): blank id. Defaults to 0. Returns: @@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0): label = np.expand_dims(label, 1) #[L, 1] blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id label = np.concatenate([blanks, label], axis=1) #[L, 2] - label = label.reshape(-1) #[2L] - label = np.append(label, label[0]) #[2L + 1] + label = label.reshape(-1) #[2L], -l-l-l + label = np.append(label, label[0]) #[2L + 1], -l-l-l- return label def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, - blank_id=0) -> list: + blank_id=0) -> List[int]: """ctc forced alignment. https://distill.pub/2017/ctc/ @@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, y (paddle.Tensor): label id sequence tensor, 1d tensor (L) blank_id (int): blank symbol index Returns: - paddle.Tensor: best alignment result, (T). + List[int]: best alignment result, (T). """ - y_insert_blank = insert_blank(y, blank_id) + y_insert_blank = insert_blank(y, blank_id) #(2L+1) log_alpha = paddle.zeros( (ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1) log_alpha = log_alpha - float('inf') # log of zero + # TODO(Hui Zhang): zeros not support paddle.int16 state_path = (paddle.zeros( - (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1 - ) # state path + (ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1 + ) # state path, Tuple((T, 2L+1)) # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb + log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): + for t in range(1, ctc_probs.size(0)): # T + for s in range(len(y_insert_blank)): # 2L+1 if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ s] == y_insert_blank[s - 2]: candidates = paddle.to_tensor( @@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor, log_alpha[t - 1, s - 2], ]) prev_state = [s, s - 1, s - 2] - log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][ - y_insert_blank[s]] + # TODO(Hui Zhang): VarBase.__getitem__() not support np.int64 + log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int( + y_insert_blank[s])] state_path[t, s] = prev_state[paddle.argmax(candidates)] - state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16) + # TODO(Hui Zhang): zeros not support paddle.int16 + state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32) candidates = paddle.to_tensor([ log_alpha[-1, len(y_insert_blank) - 1], # Sb diff --git a/deepspeech/utils/text_grid.py b/deepspeech/utils/text_grid.py new file mode 100644 index 000000000..3af58c9ba --- /dev/null +++ b/deepspeech/utils/text_grid.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict +from typing import List +from typing import Text + +import textgrid + + +def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]: + """segment ctc alignment ids by continuous blank and repeat label. + + Args: + alignment (List[int]): ctc alignment id sequence. + e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3] + blank_id (int, optional): blank id. Defaults to 0. + + Returns: + List[List[int]]: token align, segment aligment id sequence. + e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]] + """ + # convert alignment to a praat format, which is a doing phonetics + # by computer and helps analyzing alignment + align_segs = [] + # get frames level duration for each token + start = 0 + end = 0 + while end < len(alignment): + while end < len(alignment) and alignment[end] == blank_id: # blank + end += 1 + if end == len(alignment): + align_segs[-1].extend(alignment[start:]) + break + end += 1 + while end < len(alignment) and alignment[end - 1] == alignment[ + end]: # repeat label + end += 1 + align_segs.append(alignment[start:end]) + start = end + return align_segs + + +def align_to_tierformat(align_segs: List[List[int]], + subsample: int, + token_dict: Dict[int, Text], + blank_id=0) -> List[Text]: + """Generate textgrid.Interval format from alignment segmentations. + + Args: + align_segs (List[List[int]]): segmented ctc alignment ids. + subsample (int): 25ms frame_length, 10ms hop_length, 1/subsample + token_dict (Dict[int, Text]): int -> str map. + + Returns: + List[Text]: list of textgrid.Interval text, str(start, end, text). + """ + hop_length = 10 # ms + second_ms = 1000 # ms + frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length + second_per_frame = 1.0 / frame_per_second + + begin = 0 + duration = 0 + tierformat = [] + + for idx, tokens in enumerate(align_segs): + token_len = len(tokens) + token = tokens[-1] + # time duration in second + duration = token_len * subsample * second_per_frame + if idx < len(align_segs) - 1: + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + else: + for i in tokens: + if i != blank_id: + token = i + break + print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}") + tierformat.append( + f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n") + begin = begin + duration + + return tierformat + + +def generate_textgrid(maxtime: float, + intervals: List[Text], + output: Text, + name: Text='ali') -> None: + """Create alignment textgrid file. + + Args: + maxtime (float): audio duartion. + intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item. + output (Text): textgrid filepath. + name (Text, optional): tier or layer name. Defaults to 'ali'. + """ + # Download Praat: https://www.fon.hum.uva.nl/praat/ + avg_interval = maxtime / (len(intervals) + 1) + print(f"average second/token: {avg_interval}") + margin = 0.0001 + + tg = textgrid.TextGrid(maxTime=maxtime) + tier = textgrid.IntervalTier(name=name, maxTime=maxtime) + + i = 0 + for dur in intervals: + s, e, text = dur.split() + tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text) + + tg.append(tier) + + tg.write(output) + print("successfully generator textgrid {}.".format(output)) diff --git a/deepspeech/utils/utility.py b/deepspeech/utils/utility.py index 64570026b..a0639e065 100644 --- a/deepspeech/utils/utility.py +++ b/deepspeech/utils/utility.py @@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float: a_max = max(args) lsp = math.log(sum(math.exp(a - a_max) for a in args)) return a_max + lsp + + +def get_subsample(config): + """Subsample rate from config. + + Args: + config (yacs.config.CfgNode): yaml config + + Returns: + int: subsample rate. + """ + input_layer = config["model"]["encoder_conf"]["input_layer"] + assert input_layer in ["conv2d", "conv2d6", "conv2d8"] + if input_layer == "conv2d": + return 4 + elif input_layer == "conv2d6": + return 6 + elif input_layer == "conv2d8": + return 8 diff --git a/examples/aishell/s0/conf/deepspeech2.yaml b/examples/aishell/s0/conf/deepspeech2.yaml index e7a5c6dcf..fea233c7e 100644 --- a/examples/aishell/s0/conf/deepspeech2.yaml +++ b/examples/aishell/s0/conf/deepspeech2.yaml @@ -49,6 +49,9 @@ training: weight_decay: 1e-06 global_grad_clip: 3.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/aishell/s1/README.md b/examples/aishell/s1/README.md index 72a03b618..78e759c8f 100644 --- a/examples/aishell/s1/README.md +++ b/examples/aishell/s1/README.md @@ -3,11 +3,12 @@ ## Conformer | Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | --- | -| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 | -| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 | -| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 | -| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 | +| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 | + ## Chunk Conformer @@ -21,6 +22,6 @@ ## Transformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | ---| -| transformer | conf/transformer.yaml | spec_aug + shift | test | attention | - | - | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | ---| +| transformer | - | conf/transformer.yaml | spec_aug + shift | test | attention | - | - | diff --git a/examples/aishell/s1/conf/chunk_conformer.yaml b/examples/aishell/s1/conf/chunk_conformer.yaml index 0e5b8699f..3e606788e 100644 --- a/examples/aishell/s1/conf/chunk_conformer.yaml +++ b/examples/aishell/s1/conf/chunk_conformer.yaml @@ -93,6 +93,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/conf/conformer.yaml b/examples/aishell/s1/conf/conformer.yaml index 116c91927..4b1430c58 100644 --- a/examples/aishell/s1/conf/conformer.yaml +++ b/examples/aishell/s1/conf/conformer.yaml @@ -88,6 +88,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/aishell/s1/local/align.sh b/examples/aishell/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/aishell/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/aishell/s1/run.sh b/examples/aishell/s1/run.sh index 4cf09553b..65b48a976 100644 --- a/examples/aishell/s1/run.sh +++ b/examples/aishell/s1/run.sh @@ -30,10 +30,15 @@ fi if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then # test ckpt avg_n - CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 + CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/dataset/aidatatang_200zh/.gitignore b/examples/dataset/aidatatang_200zh/.gitignore new file mode 100644 index 000000000..fc56525e6 --- /dev/null +++ b/examples/dataset/aidatatang_200zh/.gitignore @@ -0,0 +1,4 @@ +*.tgz +manifest.* +*.meta +aidatatang_200zh/ \ No newline at end of file diff --git a/examples/dataset/aidatatang_200zh/README.md b/examples/dataset/aidatatang_200zh/README.md new file mode 100644 index 000000000..e6f1eefbd --- /dev/null +++ b/examples/dataset/aidatatang_200zh/README.md @@ -0,0 +1,14 @@ +# [Aidatatang_200zh](http://www.openslr.org/62/) + +Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License. +The contents and the corresponding descriptions of the corpus include: + +* The corpus contains 200 hours of acoustic data, which is mostly mobile recorded data. +* 600 speakers from different accent areas in China are invited to participate in the recording. +* The transcription accuracy for each sentence is larger than 98%. +* Recordings are conducted in a quiet indoor environment. +* The database is divided into training set, validation set, and testing set in a ratio of 7: 1: 2. +* Detail information such as speech data coding and speaker information is preserved in the metadata file. +* Segmented transcripts are also provided. + +The corpus aims to support researchers in speech recognition, machine translation, voiceprint recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use. diff --git a/examples/dataset/aidatatang_200zh/aidatatang_200zh.py b/examples/dataset/aidatatang_200zh/aidatatang_200zh.py new file mode 100644 index 000000000..cc77c3c48 --- /dev/null +++ b/examples/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare aidatatang_200zh mandarin dataset + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os + +import soundfile + +from utils.utility import download +from utils.utility import unpack + +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') + +URL_ROOT = 'http://www.openslr.org/resources/62' +# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62' +DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz' +MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949' + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/aidatatang_200zh", + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + transcript_path = os.path.join(data_dir, 'transcript', + 'aidatatang_200_zh_transcript.txt') + transcript_dict = {} + for line in codecs.open(transcript_path, 'r', 'utf-8'): + line = line.strip() + if line == '': + continue + audio_id, text = line.split(' ', 1) + # remove withespace, charactor text + text = ''.join(text.split()) + transcript_dict[audio_id] = text + + data_types = ['train', 'dev', 'test'] + for dtype in data_types: + del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + + audio_dir = os.path.join(data_dir, 'corpus/', dtype) + for subfolder, _, filelist in sorted(os.walk(audio_dir)): + for fname in filelist: + if not fname.endswith('.wav'): + continue + + audio_path = os.path.abspath(os.path.join(subfolder, fname)) + audio_id = os.path.basename(fname)[:-4] + + audio_data, samplerate = soundfile.read(audio_path) + duration = float(len(audio_data) / samplerate) + text = transcript_dict[audio_id] + json_lines.append( + json.dumps( + { + 'utt': audio_id, + 'feat': audio_path, + 'feat_shape': (duration, ), # second + 'text': text, + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(text) + total_num += 1 + + manifest_path = manifest_path_prefix + '.' + dtype + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + with open(dtype + '.meta', 'w') as f: + print(f"{dtype}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + + +def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): + """Download, unpack and create manifest file.""" + data_dir = os.path.join(target_dir, subset) + if not os.path.exists(data_dir): + filepath = download(url, md5sum, target_dir) + unpack(filepath, target_dir) + # unpack all audio tar files + audio_dir = os.path.join(data_dir, 'corpus') + for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)): + for sub in dirlist: + print(f"unpack dir {sub}...") + for folder, _, filelist in sorted( + os.walk(os.path.join(subfolder, sub))): + for ftar in filelist: + unpack(os.path.join(folder, ftar), folder, True) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + + create_manifest(data_dir, manifest_path) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + prepare_dataset( + url=DATA_URL, + md5sum=MD5_DATA, + target_dir=args.target_dir, + manifest_path=args.manifest_prefix, + subset='aidatatang_200zh') + + print("Data download and manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/dataset/aishell/.gitignore b/examples/dataset/aishell/.gitignore index 9c6e517e5..eea6573e1 100644 --- a/examples/dataset/aishell/.gitignore +++ b/examples/dataset/aishell/.gitignore @@ -1 +1,4 @@ data_aishell* +*.meta +manifest.* +*.tgz \ No newline at end of file diff --git a/examples/dataset/aishell/README.md b/examples/dataset/aishell/README.md new file mode 100644 index 000000000..6770cd207 --- /dev/null +++ b/examples/dataset/aishell/README.md @@ -0,0 +1,3 @@ +# [Aishell1](http://www.openslr.org/33/) + +This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including smart home, autonomous driving, and industrial production. The whole recording was put in quiet indoor environment, using 3 different devices at the same time: high fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit), iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy rate is above 95%, through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. ( This database is free for academic research, not in the commerce, if without permission. ) diff --git a/examples/dataset/aishell/aishell.py b/examples/dataset/aishell/aishell.py index a0cabe352..5811a401a 100644 --- a/examples/dataset/aishell/aishell.py +++ b/examples/dataset/aishell/aishell.py @@ -31,7 +31,7 @@ from utils.utility import unpack DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') URL_ROOT = 'http://www.openslr.org/resources/33' -URL_ROOT = 'https://openslr.magicdatatech.com/resources/33' +# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33' DATA_URL = URL_ROOT + '/data_aishell.tgz' MD5_DATA = '2f494334227864a8a8fec932999db9d8' @@ -60,18 +60,22 @@ def create_manifest(data_dir, manifest_path_prefix): if line == '': continue audio_id, text = line.split(' ', 1) - # remove withespace + # remove withespace, charactor text text = ''.join(text.split()) transcript_dict[audio_id] = text data_types = ['train', 'dev', 'test'] for dtype in data_types: del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + audio_dir = os.path.join(data_dir, 'wav', dtype) for subfolder, _, filelist in sorted(os.walk(audio_dir)): for fname in filelist: - audio_path = os.path.join(subfolder, fname) - audio_id = fname[:-4] + audio_path = os.path.abspath(os.path.join(subfolder, fname)) + audio_id = os.path.basename(fname)[:-4] # if no transcription for audio then skipped if audio_id not in transcript_dict: continue @@ -81,20 +85,30 @@ def create_manifest(data_dir, manifest_path_prefix): json_lines.append( json.dumps( { - 'utt': - os.path.splitext(os.path.basename(audio_path))[0], - 'feat': - audio_path, + 'utt': audio_id, + 'feat': audio_path, 'feat_shape': (duration, ), # second - 'text': - text + 'text': text }, ensure_ascii=False)) + + total_sec += duration + total_text += len(text) + total_num += 1 + manifest_path = manifest_path_prefix + '.' + dtype with codecs.open(manifest_path, 'w', 'utf-8') as fout: for line in json_lines: fout.write(line + '\n') + with open(dtype + '.meta', 'w') as f: + print(f"{dtype}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + def prepare_dataset(url, md5sum, target_dir, manifest_path): """Download, unpack and create manifest file.""" @@ -123,6 +137,8 @@ def main(): target_dir=args.target_dir, manifest_path=args.manifest_prefix) + print("Data download and manifest prepare done!") + if __name__ == '__main__': main() diff --git a/examples/dataset/aishell3/README.md b/examples/dataset/aishell3/README.md new file mode 100644 index 000000000..8a29a6d0f --- /dev/null +++ b/examples/dataset/aishell3/README.md @@ -0,0 +1,3 @@ +# [Aishell3](http://www.openslr.org/93/) + +AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus which could be used to train multi-speaker Text-to-Speech (TTS) systems. The corpus contains roughly **85 hours** of emotion-neutral recordings spoken by 218 native Chinese mandarin speakers and total 88035 utterances. Their auxiliary attributes such as gender, age group and native accents are explicitly marked and provided in the corpus. Accordingly, transcripts in Chinese character-level and pinyin-level are provided along with the recordings. The word & tone transcription accuracy rate is above 98%, through professional speech annotation and strict quality inspection for tone and prosody. ( This database is free for academic research, not in the commerce, if without permission. ) diff --git a/examples/dataset/gigaspeech/.gitignore b/examples/dataset/gigaspeech/.gitignore new file mode 100644 index 000000000..7f78176b7 --- /dev/null +++ b/examples/dataset/gigaspeech/.gitignore @@ -0,0 +1 @@ +GigaSpeech/ diff --git a/examples/dataset/gigaspeech/README.md b/examples/dataset/gigaspeech/README.md new file mode 100644 index 000000000..4a1715cb8 --- /dev/null +++ b/examples/dataset/gigaspeech/README.md @@ -0,0 +1,10 @@ +# [GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + +``` +git clone https://github.com/SpeechColab/GigaSpeech.git + +cd GigaSpeech +utils/gigaspeech_download.sh /disk1/audio_data/gigaspeech +toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data +cd .. +``` diff --git a/examples/dataset/gigaspeech/gigaspeech.py b/examples/dataset/gigaspeech/gigaspeech.py new file mode 100644 index 000000000..185a92b8d --- /dev/null +++ b/examples/dataset/gigaspeech/gigaspeech.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/dataset/gigaspeech/run.sh b/examples/dataset/gigaspeech/run.sh new file mode 100755 index 000000000..a1ad8610c --- /dev/null +++ b/examples/dataset/gigaspeech/run.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +set -e + +curdir=$PWD + +test -d GigaSpeech || git clone https://github.com/SpeechColab/GigaSpeech.git + + +pushd GigaSpeech +source env_vars.sh +./utils/download_gigaspeech.sh ${curdir}/ +#toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data +popd diff --git a/examples/dataset/librispeech/.gitignore b/examples/dataset/librispeech/.gitignore index dfd5c67b5..465806def 100644 --- a/examples/dataset/librispeech/.gitignore +++ b/examples/dataset/librispeech/.gitignore @@ -5,3 +5,5 @@ test-other train-clean-100 train-clean-360 train-other-500 +*.meta +manifest.* diff --git a/examples/dataset/librispeech/librispeech.py b/examples/dataset/librispeech/librispeech.py index 55012f73c..f549a95f1 100644 --- a/examples/dataset/librispeech/librispeech.py +++ b/examples/dataset/librispeech/librispeech.py @@ -77,6 +77,10 @@ def create_manifest(data_dir, manifest_path): """ print("Creating manifest %s ..." % manifest_path) json_lines = [] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + for subfolder, _, filelist in sorted(os.walk(data_dir)): text_filelist = [ filename for filename in filelist if filename.endswith('trans.txt') @@ -86,7 +90,9 @@ def create_manifest(data_dir, manifest_path): for line in io.open(text_filepath, encoding="utf8"): segments = line.strip().split() text = ' '.join(segments[1:]).lower() - audio_filepath = os.path.join(subfolder, segments[0] + '.flac') + + audio_filepath = os.path.abspath( + os.path.join(subfolder, segments[0] + '.flac')) audio_data, samplerate = soundfile.read(audio_filepath) duration = float(len(audio_data)) / samplerate json_lines.append( @@ -99,10 +105,24 @@ def create_manifest(data_dir, manifest_path): 'text': text })) + + total_sec += duration + total_text += len(text) + total_num += 1 + with codecs.open(manifest_path, 'w', 'utf-8') as out_file: for line in json_lines: out_file.write(line + '\n') + subset = os.path.splitext(manifest_path)[1] + with open(subset + '.meta', 'w') as f: + print(f"{subset}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + def prepare_dataset(url, md5sum, target_dir, manifest_path): """Download, unpack and create summmary manifest file. diff --git a/examples/dataset/magicdata/README.md b/examples/dataset/magicdata/README.md new file mode 100644 index 000000000..083aee97b --- /dev/null +++ b/examples/dataset/magicdata/README.md @@ -0,0 +1,15 @@ +# [MagicData](http://www.openslr.org/68/) + +MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use. +The contents and the corresponding descriptions of the corpus include: + +* The corpus contains 755 hours of speech data, which is mostly mobile recorded data. +* 1080 speakers from different accent areas in China are invited to participate in the recording. +* The sentence transcription accuracy is higher than 98%. +* Recordings are conducted in a quiet indoor environment. +* The database is divided into training set, validation set, and testing set in a ratio of 51: 1: 2. +* Detail information such as speech data coding and speaker information is preserved in the metadata file. +* The domain of recording texts is diversified, including interactive Q&A, music search, SNS messages, home command and control, etc. +* Segmented transcripts are also provided. + +The corpus aims to support researchers in speech recognition, machine translation, speaker recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use. diff --git a/examples/dataset/mini_librispeech/.gitignore b/examples/dataset/mini_librispeech/.gitignore index 61f54c966..7fbcfd65d 100644 --- a/examples/dataset/mini_librispeech/.gitignore +++ b/examples/dataset/mini_librispeech/.gitignore @@ -2,3 +2,4 @@ dev-clean/ manifest.dev-clean manifest.train-clean train-clean/ +*.meta diff --git a/examples/dataset/mini_librispeech/mini_librispeech.py b/examples/dataset/mini_librispeech/mini_librispeech.py index f5bc13933..44a6d3671 100644 --- a/examples/dataset/mini_librispeech/mini_librispeech.py +++ b/examples/dataset/mini_librispeech/mini_librispeech.py @@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path): """ print("Creating manifest %s ..." % manifest_path) json_lines = [] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + for subfolder, _, filelist in sorted(os.walk(data_dir)): text_filelist = [ filename for filename in filelist if filename.endswith('trans.txt') @@ -80,10 +84,24 @@ def create_manifest(data_dir, manifest_path): 'text': text })) + + total_sec += duration + total_text += len(text) + total_num += 1 + with codecs.open(manifest_path, 'w', 'utf-8') as out_file: for line in json_lines: out_file.write(line + '\n') + subset = os.path.splitext(manifest_path)[1] + with open(subset + '.meta', 'w') as f: + print(f"{subset}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + def prepare_dataset(url, md5sum, target_dir, manifest_path): """Download, unpack and create summmary manifest file. diff --git a/examples/dataset/multi_cn/README.md b/examples/dataset/multi_cn/README.md new file mode 100644 index 000000000..d59b11b6d --- /dev/null +++ b/examples/dataset/multi_cn/README.md @@ -0,0 +1,11 @@ +# multi-cn + +This is a Chinese speech recognition recipe that trains on all Chinese corpora on OpenSLR, including: + +* Aidatatang (140 hours) +* Aishell (151 hours) +* MagicData (712 hours) +* Primewords (99 hours) +* ST-CMDS (110 hours) +* THCHS-30 (26 hours) +* optional AISHELL2 (~1000 hours) if available diff --git a/examples/dataset/primewords/README.md b/examples/dataset/primewords/README.md new file mode 100644 index 000000000..a4f1ed65d --- /dev/null +++ b/examples/dataset/primewords/README.md @@ -0,0 +1,6 @@ +# [Primewords](http://www.openslr.org/47/) + +This free Chinese Mandarin speech corpus set is released by Shanghai Primewords Information Technology Co., Ltd. +The corpus is recorded by smart mobile phones from 296 native Chinese speakers. The transcription accuracy is larger than 98%, at the confidence level of 95%. It is free for academic use. + +The mapping between the transcript and utterance is given in JSON format. diff --git a/examples/dataset/st-cmds/README.md b/examples/dataset/st-cmds/README.md new file mode 100644 index 000000000..c7ae50e59 --- /dev/null +++ b/examples/dataset/st-cmds/README.md @@ -0,0 +1 @@ +# [FreeST](http://www.openslr.org/38/) diff --git a/examples/dataset/thchs30/.gitignore b/examples/dataset/thchs30/.gitignore new file mode 100644 index 000000000..b94cd7e40 --- /dev/null +++ b/examples/dataset/thchs30/.gitignore @@ -0,0 +1,6 @@ +*.tgz +manifest.* +data_thchs30 +resource +test-noise +*.meta diff --git a/examples/dataset/thchs30/README.md b/examples/dataset/thchs30/README.md new file mode 100644 index 000000000..6b59d663a --- /dev/null +++ b/examples/dataset/thchs30/README.md @@ -0,0 +1,55 @@ +# [THCHS30](http://www.openslr.org/18/) + +This is the *data part* of the `THCHS30 2015` acoustic data +& scripts dataset. + +The dataset is described in more detail in the paper ``THCHS-30 : A Free +Chinese Speech Corpus`` by Dong Wang, Xuewei Zhang. + +A paper (if it can be called a paper) 13 years ago regarding the database: + +Dong Wang, Dalei Wu, Xiaoyan Zhu, ``TCMSD: A new Chinese Continuous Speech Database``, +International Conference on Chinese Computing (ICCC'01), 2001, Singapore. + +The layout of this data pack is the following: + + ``data`` + ``*.wav`` + audio data + + ``*.wav.trn`` + transcriptions + + ``{train,dev,test}`` + contain symlinks into the ``data`` directory for both audio and + transcription files. Contents of these directories define the + train/dev/test split of the data. + + ``{lm_word}`` + ``word.3gram.lm`` + trigram LM based on word + ``lexicon.txt`` + lexicon based on word + + ``{lm_phone}`` + ``phone.3gram.lm`` + trigram LM based on phone + ``lexicon.txt`` + lexicon based on phone + + ``README.TXT`` + this file + + +Data statistics +=============== + +Statistics for the data are as follows: + + =========== ========== ========== =========== + **dataset** **audio** **#sents** **#words** + =========== ========== ========== =========== + train 25 10,000 198,252 + dev 2:14 893 17,743 + test 6:15 2,495 49,085 + =========== ========== ========== =========== diff --git a/examples/dataset/thchs30/thchs30.py b/examples/dataset/thchs30/thchs30.py new file mode 100644 index 000000000..d03e3a22e --- /dev/null +++ b/examples/dataset/thchs30/thchs30.py @@ -0,0 +1,184 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare THCHS-30 mandarin dataset + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os +from multiprocessing.pool import Pool +from pathlib import Path + +import soundfile + +from utils.utility import download +from utils.utility import unpack + +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') + +URL_ROOT = 'http://www.openslr.org/resources/18' +# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18' +DATA_URL = URL_ROOT + '/data_thchs30.tgz' +TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz' +RESOURCE_URL = URL_ROOT + '/resource.tgz' +MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90' +MD5_TEST_NOISE = '7e8a985fb965b84141b68c68556c2030' +MD5_RESOURCE = 'c0b2a565b4970a0c4fe89fefbf2d97e1' + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/THCHS30", + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + + +def read_trn(filepath): + """read trn file. + word text in first line. + syllable text in second line. + phoneme text in third line. + + Args: + filepath (str): trn path. + + Returns: + list(str): (word, syllable, phone) + """ + texts = [] + with open(filepath, 'r') as f: + lines = f.read().strip().split('\n') + assert len(lines) == 3, lines + # charactor text, remove withespace + texts.append(''.join(lines[0].split())) + texts.extend(lines[1:]) + return texts + + +def resolve_symlink(filepath): + """resolve symlink which content is norm file. + + Args: + filepath (str): norm file symlink. + """ + sym_path = Path(filepath) + relative_link = sym_path.read_text().strip() + relative = Path(relative_link) + relpath = sym_path.parent / relative + return relpath.resolve() + + +def create_manifest(data_dir, manifest_path_prefix): + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + data_types = ['train', 'dev', 'test'] + for dtype in data_types: + del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + + audio_dir = os.path.join(data_dir, dtype) + for subfolder, _, filelist in sorted(os.walk(audio_dir)): + for fname in filelist: + file_path = os.path.join(subfolder, fname) + if file_path.endswith('.wav'): + audio_path = os.path.abspath(file_path) + text_path = resolve_symlink(audio_path + '.trn') + else: + continue + + assert os.path.exists(audio_path) and os.path.exists(text_path) + + audio_id = os.path.basename(audio_path)[:-4] + word_text, syllable_text, phone_text = read_trn(text_path) + audio_data, samplerate = soundfile.read(audio_path) + duration = float(len(audio_data) / samplerate) + + # not dump alignment infos + json_lines.append( + json.dumps( + { + 'utt': audio_id, + 'feat': audio_path, + 'feat_shape': (duration, ), # second + 'text': word_text, # charactor + 'syllable': syllable_text, + 'phone': phone_text, + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(word_text) + total_num += 1 + + manifest_path = manifest_path_prefix + '.' + dtype + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + with open(dtype + '.meta', 'w') as f: + print(f"{dtype}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + + +def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): + """Download, unpack and create manifest file.""" + datadir = os.path.join(target_dir, subset) + if not os.path.exists(datadir): + filepath = download(url, md5sum, target_dir) + unpack(filepath, target_dir) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + + if subset == 'data_thchs30': + create_manifest(datadir, manifest_path) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + tasks = [ + (DATA_URL, MD5_DATA, args.target_dir, args.manifest_prefix, + "data_thchs30"), + (TEST_NOISE_URL, MD5_TEST_NOISE, args.target_dir, args.manifest_prefix, + "test-noise"), + (RESOURCE_URL, MD5_RESOURCE, args.target_dir, args.manifest_prefix, + "resource"), + ] + with Pool(7) as pool: + pool.starmap(prepare_dataset, tasks) + + print("Data download and manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/dataset/timit/.gitignore b/examples/dataset/timit/.gitignore new file mode 100644 index 000000000..9a3f42281 --- /dev/null +++ b/examples/dataset/timit/.gitignore @@ -0,0 +1,4 @@ +TIMIT.* +TIMIT +manifest.* +*.meta diff --git a/examples/dataset/timit/timit.py b/examples/dataset/timit/timit.py new file mode 100644 index 000000000..222d9af30 --- /dev/null +++ b/examples/dataset/timit/timit.py @@ -0,0 +1,239 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Prepare Librispeech ASR datasets. + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +import argparse +import codecs +import json +import os +import re +import string +from pathlib import Path + +import soundfile + +from utils.utility import unzip + +URL_ROOT = "" +MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d" + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default='~/.cache/paddle/dataset/speech/timit', + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +args = parser.parse_args() + +#: A string containing Chinese punctuation marks (non-stops). +non_stops = ( + # Fullwidth ASCII variants + '\uFF02\uFF03\uFF04\uFF05\uFF06\uFF07\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D' + '\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF20\uFF3B\uFF3C\uFF3D\uFF3E\uFF3F' + '\uFF40\uFF5B\uFF5C\uFF5D\uFF5E\uFF5F\uFF60' + + # Halfwidth CJK punctuation + '\uFF62\uFF63\uFF64' + + # CJK symbols and punctuation + '\u3000\u3001\u3003' + + # CJK angle and corner brackets + '\u3008\u3009\u300A\u300B\u300C\u300D\u300E\u300F\u3010\u3011' + + # CJK brackets and symbols/punctuation + '\u3014\u3015\u3016\u3017\u3018\u3019\u301A\u301B\u301C\u301D\u301E\u301F' + + # Other CJK symbols + '\u3030' + + # Special CJK indicators + '\u303E\u303F' + + # Dashes + '\u2013\u2014' + + # Quotation marks and apostrophe + '\u2018\u2019\u201B\u201C\u201D\u201E\u201F' + + # General punctuation + '\u2026\u2027' + + # Overscores and underscores + '\uFE4F' + + # Small form variants + '\uFE51\uFE54' + + # Latin punctuation + '\u00B7') + +#: A string of Chinese stops. +stops = ( + '\uFF01' # Fullwidth exclamation mark + '\uFF1F' # Fullwidth question mark + '\uFF61' # Halfwidth ideographic full stop + '\u3002' # Ideographic full stop +) + +#: A string containing all Chinese punctuation. +punctuation = non_stops + stops + + +def tn(text): + # lower text + text = text.lower() + # remove punc + text = re.sub(f'[{punctuation}{string.punctuation}]', "", text) + return text + + +def read_txt(filepath: str) -> str: + with open(filepath, 'r') as f: + line = f.read().strip().split(maxsplit=2)[2] + return tn(line) + + +def read_algin(filepath: str) -> str: + """read word or phone alignment file. + + + Args: + filepath (str): [description] + + Returns: + str: token sepearte by + """ + aligns = [] # (start, end, token) + with open(filepath, 'r') as f: + for line in f: + items = line.strip().split() + # for phone: (Note: beginning and ending silence regions are marked with h#) + if items[2].strip() == 'h#': + continue + aligns.append(items) + return ' '.join([item[2] for item in aligns]) + + +def create_manifest(data_dir, manifest_path_prefix): + """Create a manifest json file summarizing the data set, with each line + containing the meta data (i.e. audio filepath, transcription text, audio + duration) of each audio file within the data set. + """ + print("Creating manifest %s ..." % manifest_path_prefix) + json_lines = [] + utts = set() + + data_types = ['TRAIN', 'TEST'] + for dtype in data_types: + del json_lines[:] + total_sec = 0.0 + total_text = 0.0 + total_num = 0 + + audio_dir = Path(os.path.join(data_dir, dtype)) + for fname in sorted(audio_dir.rglob('*.WAV')): + audio_path = fname.resolve() # .WAV + audio_id = audio_path.stem + # if uttid exits, then skipped + if audio_id in utts: + continue + + utts.add(audio_id) + text_path = audio_path.with_suffix('.TXT') + phone_path = audio_path.with_suffix('.PHN') + word_path = audio_path.with_suffix('.WRD') + + audio_data, samplerate = soundfile.read( + str(audio_path), dtype='int16') + duration = float(len(audio_data) / samplerate) + word_text = read_txt(text_path) + phone_text = read_algin(phone_path) + + gender_spk = str(audio_path.parent.stem) + spk = gender_spk[1:] + gender = gender_spk[0] + utt_id = '_'.join([spk, gender, audio_id]) + # not dump alignment infos + json_lines.append( + json.dumps( + { + 'utt': utt_id, + 'feat': str(audio_path), + 'feat_shape': (duration, ), # second + 'text': word_text, # word + 'phone': phone_text, + 'spk': spk, + 'gender': gender, + }, + ensure_ascii=False)) + + total_sec += duration + total_text += len(word_text.split()) + total_num += 1 + + manifest_path = manifest_path_prefix + '.' + dtype.lower() + with codecs.open(manifest_path, 'w', 'utf-8') as fout: + for line in json_lines: + fout.write(line + '\n') + + with open(dtype.lower() + '.meta', 'w') as f: + print(f"{dtype}:", file=f) + print(f"{total_num} utts", file=f) + print(f"{total_sec / (60*60)} h", file=f) + print(f"{total_text} text", file=f) + print(f"{total_text / total_sec} text/sec", file=f) + print(f"{total_sec / total_num} sec/utt", file=f) + + +def prepare_dataset(url, md5sum, target_dir, manifest_path): + """Download, unpack and create summmary manifest file. + """ + filepath = os.path.join(target_dir, "TIMIT.zip") + if not os.path.exists(filepath): + print(f"Please download TIMIT.zip into {target_dir}.") + raise FileNotFoundError + + if not os.path.exists(os.path.join(target_dir, "TIMIT")): + # check md5sum + assert check_md5sum(filepath, md5sum) + # unpack + unzip(filepath, target_dir) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + # create manifest json file + create_manifest(os.path.join(target_dir, "TIMIT"), manifest_path) + + +def main(): + if args.target_dir.startswith('~'): + args.target_dir = os.path.expanduser(args.target_dir) + + prepare_dataset(URL_ROOT, MD5_DATA, args.target_dir, args.manifest_prefix) + print("Data download and manifest prepare done!") + + +if __name__ == '__main__': + main() diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index dde288bdd..76aa5e78a 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -2,8 +2,8 @@ ## Deepspeech2 -| Model | Params | Release | Config | Test set | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | -| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | -| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | +| Model | Params | release | Config | Test set | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | +| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | 15.184467315673828 | test-clean | 0.072154 | +| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | - | test-clean | 0.073973 | +| DeepSpeech2 | 42.96M | 1.8.5 | - | test-clean | - | 0.074939 | diff --git a/examples/librispeech/s0/conf/deepspeech2.yaml b/examples/librispeech/s0/conf/deepspeech2.yaml index b419cbe26..acee94c3e 100644 --- a/examples/librispeech/s0/conf/deepspeech2.yaml +++ b/examples/librispeech/s0/conf/deepspeech2.yaml @@ -48,6 +48,9 @@ training: weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: batch_size: 128 diff --git a/examples/librispeech/s1/README.md b/examples/librispeech/s1/README.md index 73f6156d9..5e23c0ab5 100644 --- a/examples/librispeech/s1/README.md +++ b/examples/librispeech/s1/README.md @@ -2,17 +2,17 @@ ## Conformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| conformer | conf/conformer.yaml | spec_aug + shift | test-all | attention | test-all 6.35 | 0.057117 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.35 | 0.030162 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | test-all 6.35 | 0.037910 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | test-all 6.35 | 0.037761 | -| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | test-all 6.35 | 0.032115 | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-all | attention | 6.35 | 0.057117 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | 6.35 | 0.030162 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 6.35 | 0.037910 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 6.35 | 0.037761 | +| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 6.35 | 0.032115 | ## Transformer -| Model | Config | Augmentation| Test set | Decode method | Loss | WER | -| --- | --- | --- | --- | --- | --- | --- | -| transformer | conf/transformer.yaml | spec_aug + shift | test-all | attention | test-all 6.98 | 0.066500 | -| transformer | conf/transformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.98 | 0.036 | +| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER | +| --- | --- | --- | --- | --- | --- | --- | --- | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-all | attention | 6.98 | 0.066500 | +| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | 6.98 | 0.036 | diff --git a/examples/librispeech/s1/conf/chunk_confermer.yaml b/examples/librispeech/s1/conf/chunk_confermer.yaml index ec945a188..5af689594 100644 --- a/examples/librispeech/s1/conf/chunk_confermer.yaml +++ b/examples/librispeech/s1/conf/chunk_confermer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 4 min_input_len: 0.5 max_input_len: 20.0 min_output_len: 0.0 max_output_len: 400.0 min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 16 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -80,7 +82,7 @@ model: training: n_epoch: 120 - accum_grad: 1 + accum_grad: 8 global_grad_clip: 5.0 optim: adam optim_conf: @@ -91,6 +93,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/chunk_transformer.yaml b/examples/librispeech/s1/conf/chunk_transformer.yaml index 3939ffc68..f782a0373 100644 --- a/examples/librispeech/s1/conf/chunk_transformer.yaml +++ b/examples/librispeech/s1/conf/chunk_transformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -84,6 +86,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: @@ -103,6 +108,6 @@ decoding: # >0: for decoding, use fixed chunk size as set. # 0: used for training, it's prohibited here. num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1. - simulate_streaming: False # simulate streaming inference. Defaults to False. + simulate_streaming: true # simulate streaming inference. Defaults to False. diff --git a/examples/librispeech/s1/conf/conformer.yaml b/examples/librispeech/s1/conf/conformer.yaml index 8f8bf4539..955b6108b 100644 --- a/examples/librispeech/s1/conf/conformer.yaml +++ b/examples/librispeech/s1/conf/conformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 16 min_input_len: 0.5 # seconds max_input_len: 20.0 # seconds min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 16 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -87,6 +89,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/conf/transformer.yaml b/examples/librispeech/s1/conf/transformer.yaml index a094b0fba..8a769dca4 100644 --- a/examples/librispeech/s1/conf/transformer.yaml +++ b/examples/librispeech/s1/conf/transformer.yaml @@ -3,18 +3,20 @@ data: train_manifest: data/manifest.train dev_manifest: data/manifest.dev test_manifest: data/manifest.test-clean - vocab_filepath: data/vocab.txt - unit_type: 'spm' - spm_model_prefix: 'data/bpe_unigram_5000' - mean_std_filepath: "" - augmentation_config: conf/augmentation.json - batch_size: 64 min_input_len: 0.5 # second max_input_len: 20.0 # second min_output_len: 0.0 # tokens max_output_len: 400.0 # tokens min_output_input_ratio: 0.05 max_output_input_ratio: 10.0 + +collator: + vocab_filepath: data/vocab.txt + unit_type: 'spm' + spm_model_prefix: 'data/bpe_unigram_5000' + mean_std_filepath: "" + augmentation_config: conf/augmentation.json + batch_size: 64 raw_wav: True # use raw_wav or kaldi feature specgram_type: fbank #linear, mfcc, fbank feat_dim: 80 @@ -82,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 100 + checkpoint: + kbest_n: 50 + latest_n: 5 decoding: diff --git a/examples/librispeech/s1/local/align.sh b/examples/librispeech/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/librispeech/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/librispeech/s1/run.sh b/examples/librispeech/s1/run.sh index 65194d902..b81e8dcfd 100755 --- a/examples/librispeech/s1/run.sh +++ b/examples/librispeech/s1/run.sh @@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi diff --git a/examples/ngram_lm/README.md b/examples/ngram_lm/s0/README.md similarity index 100% rename from examples/ngram_lm/README.md rename to examples/ngram_lm/s0/README.md diff --git a/examples/ngram_lm/data/README.md b/examples/ngram_lm/s0/data/README.md similarity index 100% rename from examples/ngram_lm/data/README.md rename to examples/ngram_lm/s0/data/README.md diff --git a/examples/ngram_lm/data/custom_confusion.txt b/examples/ngram_lm/s0/data/custom_confusion.txt similarity index 100% rename from examples/ngram_lm/data/custom_confusion.txt rename to examples/ngram_lm/s0/data/custom_confusion.txt diff --git a/examples/ngram_lm/data/text_correct.txt b/examples/ngram_lm/s0/data/text_correct.txt similarity index 100% rename from examples/ngram_lm/data/text_correct.txt rename to examples/ngram_lm/s0/data/text_correct.txt diff --git a/examples/ngram_lm/local/build_zh_lm.sh b/examples/ngram_lm/s0/local/build_zh_lm.sh similarity index 100% rename from examples/ngram_lm/local/build_zh_lm.sh rename to examples/ngram_lm/s0/local/build_zh_lm.sh diff --git a/examples/ngram_lm/local/download_lm_zh.sh b/examples/ngram_lm/s0/local/download_lm_zh.sh similarity index 100% rename from examples/ngram_lm/local/download_lm_zh.sh rename to examples/ngram_lm/s0/local/download_lm_zh.sh diff --git a/examples/ngram_lm/local/kenlm_score_test.py b/examples/ngram_lm/s0/local/kenlm_score_test.py similarity index 100% rename from examples/ngram_lm/local/kenlm_score_test.py rename to examples/ngram_lm/s0/local/kenlm_score_test.py diff --git a/examples/ngram_lm/path.sh b/examples/ngram_lm/s0/path.sh similarity index 69% rename from examples/ngram_lm/path.sh rename to examples/ngram_lm/s0/path.sh index 84e2de7d0..5f580bc4b 100644 --- a/examples/ngram_lm/path.sh +++ b/examples/ngram_lm/s0/path.sh @@ -1,4 +1,4 @@ -export MAIN_ROOT=${PWD}/../../ +export MAIN_ROOT=${PWD}/../../../ export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} export LC_ALL=C @@ -7,4 +7,4 @@ export LC_ALL=C export PYTHONIOENCODING=UTF-8 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} -export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} \ No newline at end of file +export LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} diff --git a/examples/ngram_lm/requirements.txt b/examples/ngram_lm/s0/requirements.txt similarity index 100% rename from examples/ngram_lm/requirements.txt rename to examples/ngram_lm/s0/requirements.txt diff --git a/examples/ngram_lm/run.sh b/examples/ngram_lm/s0/run.sh similarity index 100% rename from examples/ngram_lm/run.sh rename to examples/ngram_lm/s0/run.sh diff --git a/examples/tiny/s0/conf/deepspeech2.yaml b/examples/tiny/s0/conf/deepspeech2.yaml index 6737d1b75..ea433f341 100644 --- a/examples/tiny/s0/conf/deepspeech2.yaml +++ b/examples/tiny/s0/conf/deepspeech2.yaml @@ -43,12 +43,16 @@ model: share_rnn_weights: True training: - n_epoch: 24 + n_epoch: 10 lr: 1e-5 lr_decay: 1.0 weight_decay: 1e-06 global_grad_clip: 5.0 log_interval: 1 + checkpoint: + kbest_n: 3 + latest_n: 2 + decoding: batch_size: 128 diff --git a/examples/tiny/s1/conf/chunk_confermer.yaml b/examples/tiny/s1/conf/chunk_confermer.yaml index 790066264..606300bdf 100644 --- a/examples/tiny/s1/conf/chunk_confermer.yaml +++ b/examples/tiny/s1/conf/chunk_confermer.yaml @@ -91,6 +91,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/chunk_transformer.yaml b/examples/tiny/s1/conf/chunk_transformer.yaml index aa2b145a6..72d368485 100644 --- a/examples/tiny/s1/conf/chunk_transformer.yaml +++ b/examples/tiny/s1/conf/chunk_transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/conformer.yaml b/examples/tiny/s1/conf/conformer.yaml index 3813daa04..a6f730501 100644 --- a/examples/tiny/s1/conf/conformer.yaml +++ b/examples/tiny/s1/conf/conformer.yaml @@ -87,6 +87,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/conf/transformer.yaml b/examples/tiny/s1/conf/transformer.yaml index 250995faa..71cbdde7f 100644 --- a/examples/tiny/s1/conf/transformer.yaml +++ b/examples/tiny/s1/conf/transformer.yaml @@ -84,6 +84,9 @@ training: warmup_steps: 25000 lr_decay: 1.0 log_interval: 1 + checkpoint: + kbest_n: 10 + latest_n: 1 decoding: diff --git a/examples/tiny/s1/local/align.sh b/examples/tiny/s1/local/align.sh new file mode 100755 index 000000000..926cb9397 --- /dev/null +++ b/examples/tiny/s1/local/align.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash + +if [ $# != 2 ];then + echo "usage: ${0} config_path ckpt_path_prefix" + exit -1 +fi + +ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') +echo "using $ngpu gpus..." + +device=gpu +if [ ngpu == 0 ];then + device=cpu +fi +config_path=$1 +ckpt_prefix=$2 + +ckpt_name=$(basename ${ckpt_prefxi}) + +mkdir -p exp + + + +batch_size=1 +output_dir=${ckpt_prefix} +mkdir -p ${output_dir} + +# align dump in `result_file` +# .tier, .TextGrid dump in `dir of result_file` +python3 -u ${BIN_DIR}/alignment.py \ +--device ${device} \ +--nproc 1 \ +--config ${config_path} \ +--result_file ${output_dir}/${type}.align \ +--checkpoint_path ${ckpt_prefix} \ +--opts decoding.batch_size ${batch_size} + +if [ $? -ne 0 ]; then + echo "Failed in ctc alignment!" + exit 1 +fi + +exit 0 diff --git a/examples/tiny/s1/run.sh b/examples/tiny/s1/run.sh index b148869b7..41f845b05 100755 --- a/examples/tiny/s1/run.sh +++ b/examples/tiny/s1/run.sh @@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # ctc alignment of test data + CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1 +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then # export ckpt avg_n - CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit + CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit fi + diff --git a/speechnn/CMakeLists.txt b/speechnn/CMakeLists.txt index e69de29bb..878374bab 100644 --- a/speechnn/CMakeLists.txt +++ b/speechnn/CMakeLists.txt @@ -0,0 +1,77 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(deepspeech VERSION 0.1) + +set(CMAKE_VERBOSE_MAKEFILE on) +# set std-14 +set(CMAKE_CXX_STANDARD 14) + +# include file +include(FetchContent) +include(ExternalProject) +# fc_patch dir +set(FETCHCONTENT_QUIET off) +get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}") +set(FETCHCONTENT_BASE_DIR ${fc_patch}) + + +############################################################################### +# Option Configurations +############################################################################### +# option configurations +option(TEST_DEBUG "option for debug" OFF) + + +############################################################################### +# Include third party +############################################################################### +# #example for include third party +# FetchContent_Declare() +# # FetchContent_MakeAvailable was not added until CMake 3.14 +# FetchContent_MakeAvailable() +# include_directories() + +# ABSEIL-CPP +include(FetchContent) +FetchContent_Declare( + absl + GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git" + GIT_TAG "20210324.1" +) +FetchContent_MakeAvailable(absl) + +# libsndfile +include(FetchContent) +FetchContent_Declare( + libsndfile + GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git" + GIT_TAG "1.0.31" +) +FetchContent_MakeAvailable(libsndfile) + + +############################################################################### +# Add local library +############################################################################### +# system lib +find_package() +# if dir have CmakeLists.txt +add_subdirectory() +# if dir do not have CmakeLists.txt +add_library(lib_name STATIC file.cc) +target_link_libraries(lib_name item0 item1) +add_dependencies(lib_name depend-target) + + +############################################################################### +# Library installation +############################################################################### +install() + + +############################################################################### +# Build binary file +############################################################################### +add_executable() +target_link_libraries() + diff --git a/speechnn/core/decoder/CMakeLists.txt b/speechnn/core/decoder/CMakeLists.txt index e69de29bb..259261bdf 100644 --- a/speechnn/core/decoder/CMakeLists.txt +++ b/speechnn/core/decoder/CMakeLists.txt @@ -0,0 +1,2 @@ +aux_source_directory(. DIR_LIB_SRCS) +add_library(decoder STATIC ${DIR_LIB_SRCS}) diff --git a/tools/Makefile b/tools/Makefile index dd5902373..94e5ea2f7 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -19,7 +19,7 @@ kenlm.done: apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50 test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install - cd kenlm && python setup.py install + source venv/bin/activate; cd kenlm && python setup.py install touch kenlm.done sox.done: @@ -32,4 +32,4 @@ sox.done: soxbindings.done: test -d soxbindings || git clone https://github.com/pseeth/soxbindings.git source venv/bin/activate; cd soxbindings && python setup.py install - touch soxbindings.done \ No newline at end of file + touch soxbindings.done diff --git a/utils/utility.py b/utils/utility.py index 0333bc559..344900efa 100644 --- a/utils/utility.py +++ b/utils/utility.py @@ -14,9 +14,15 @@ import os import tarfile import zipfile +from typing import Text from paddle.dataset.common import md5file +__all__ = [ + "check_md5sum", "getfile_insensitive", "download_multi", "download", + "unpack", "unzip" +] + def getfile_insensitive(path): """Get the actual file path when given insensitive filename.""" @@ -54,6 +60,19 @@ def download(url, md5sum, target_dir): return filepath +def check_md5sum(filepath: Text, md5sum: Text) -> bool: + """check md5sum of file. + + Args: + filepath (Text): [description] + md5sum (Text): [description] + + Returns: + bool: same or not. + """ + return md5file(filepath) == md5sum + + def unpack(filepath, target_dir, rm_tar=False): """Unpack the file to the target_dir.""" print("Unpacking %s ..." % filepath)