Merge branch 'develop' into spec_aug2


@ -34,9 +34,12 @@ from deepspeech.models.u2 import U2Model
from deepspeech.training.gradclip import ClipGradByGlobalNormWithLog
from deepspeech.training.scheduler import WarmupLR
from deepspeech.training.trainer import Trainer
from deepspeech.utils import ctc_utils
from deepspeech.utils import error_rate
from deepspeech.utils import layer_tools
from deepspeech.utils import mp_tools
from deepspeech.utils import text_grid
from deepspeech.utils import utility
from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
@ -278,7 +281,15 @@ class U2Trainer(Trainer):
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test Dataloader!")
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):
config = self.config
@ -498,6 +509,73 @@ class U2Tester(U2Trainer):
except KeyboardInterrupt:
sys.exit(-1)
@paddle.no_grad()
def align(self):
if self.config.decoding.batch_size > 1:
logger.fatal('alignment mode must be run with batch_size == 1')
sys.exit(1)
# result file must end with .align
assert self.args.result_file and self.args.result_file.endswith(
'.align')
self.model.eval()
logger.info(f"Align Total Examples: {len(self.align_loader.dataset)}")
stride_ms = self.align_loader.collate_fn.stride_ms
token_dict = self.align_loader.collate_fn.vocab_list
with open(self.args.result_file, 'w') as fout:
# one example in batch
for i, batch in enumerate(self.align_loader):
key, feat, feats_length, target, target_length = batch
# 1. Encoder
encoder_out, encoder_mask = self.model._forward_encoder(
feat, feats_length) # (B, maxlen, encoder_dim)
maxlen = encoder_out.size(1)
ctc_probs = self.model.ctc.log_softmax(
encoder_out) # (1, maxlen, vocab_size)
# 2. alignment
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = ctc_utils.forced_align(ctc_probs, target)
logger.info("align ids", key[0], alignment)
fout.write('{} {}\n'.format(key[0], alignment))
# 3. gen praat
# segment alignment
align_segs = text_grid.segment_alignment(alignment)
logger.info("align tokens", key[0], align_segs)
# IntervalTier, List["start end token\n"]
subsample = utility.get_subsample(self.config)
tierformat = text_grid.align_to_tierformat(
align_segs, subsample, token_dict)
# write tier
align_output_path = os.path.join(
os.path.dirname(self.args.result_file), "align")
tier_path = os.path.join(align_output_path, key[0] + ".tier")
with open(tier_path, 'w') as f:
f.writelines(tierformat)
# write textgrid
textgrid_path = os.path.join(align_output_path,
key[0] + ".TextGrid")
second_per_frame = 1. / (1000. /
stride_ms) # 25ms window, 10ms stride
second_per_example = (
len(alignment) + 1) * subsample * second_per_frame
text_grid.generate_textgrid(
maxtime=second_per_example,
intervals=tierformat,
output=textgrid_path)
def run_align(self):
self.resume_or_scratch()
try:
self.align()
except KeyboardInterrupt:
sys.exit(-1)
def load_inferspec(self):
"""infer model and input spec.
@ -511,10 +589,9 @@ class U2Tester(U2Trainer):
self.args.checkpoint_path)
feat_dim = self.test_loader.collate_fn.feature_size
input_spec = [
paddle.static.InputSpec(
shape=[None, feat_dim, None],
dtype='float32'), # audio, [B,D,T]
paddle.static.InputSpec(shape=[None],
paddle.static.InputSpec(shape=[1, None, feat_dim],
dtype='float32'), # audio, [B,T,D]
paddle.static.InputSpec(shape=[1],
dtype='int64'), # audio_length, [B]
]
return infer_model, input_spec
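For reference, a minimal sketch of what the updated spec pins down: the exported audio input is now batch-major (B, T, D) with B fixed to 1 and a dynamic time axis, and audio_length is a length-1 int64 vector. The feat_dim of 80 below is an assumption, matching the fbank configs elsewhere in this change.

```python
import paddle

# audio feature: (B=1, T=dynamic, D=feat_dim); feat_dim=80 is an assumed value
audio = paddle.static.InputSpec(shape=[1, None, 80], dtype='float32')
# audio_length: one int64 length for the single-item batch
audio_len = paddle.static.InputSpec(shape=[1], dtype='int64')
```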

@ -599,26 +599,26 @@ class U2BaseModel(nn.Module):
best_index = i
return hyps[best_index][0]
@jit.export
#@jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@jit.export
#@jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@jit.export
#@jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@jit.export
#@jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
@ -654,12 +654,14 @@ class U2BaseModel(nn.Module):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
@jit.export
# @jit.export([
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (paddle.Tensor): encoder output
xs (paddle.Tensor): encoder output, (B, T, D)
Returns:
paddle.Tensor: activation before ctc
"""
@ -894,7 +896,7 @@ class U2Model(U2BaseModel):
model = cls.from_config(config)
if checkpoint_path:
infos = checkpoint.load_parameters(
infos = checkpoint.Checkpoint().load_parameters(
model, checkpoint_path=checkpoint_path)
logger.info(f"checkpoint info: {infos}")
layer_tools.summary(model)

@ -18,8 +18,8 @@ import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from deepspeech.utils import checkpoint
from deepspeech.utils import mp_tools
from deepspeech.utils.checkpoint import Checkpoint
from deepspeech.utils.log import Log
__all__ = ["Trainer"]
@ -139,7 +139,7 @@ class Trainer():
"epoch": self.epoch,
"lr": self.optimizer.get_lr()
})
checkpoint.save_parameters(self.checkpoint_dir, self.iteration
self.checkpoint.add_checkpoint(self.checkpoint_dir, self.iteration
if tag is None else tag, self.model,
self.optimizer, infos)
@ -151,7 +151,7 @@ class Trainer():
resume training.
"""
scratch = None
infos = checkpoint.load_parameters(
infos = self.checkpoint.load_latest_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
@ -180,7 +180,7 @@ class Trainer():
from_scratch = self.resume_or_scratch()
if from_scratch:
# save init model, i.e. 0 epoch
self.save(tag='init')
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.iteration)
if self.parallel:
@ -263,6 +263,10 @@ class Trainer():
self.checkpoint_dir = checkpoint_dir
self.checkpoint = Checkpoint(
kbest_n=self.config.training.checkpoint.kbest_n,
latest_n=self.config.training.checkpoint.latest_n)
@mp_tools.rank_zero_only
def destory(self):
"""Close visualizer to avoid hanging after training"""

@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import os
import re
from pathlib import Path
from typing import Mapping
from typing import Text
from typing import Union
import paddle
@ -25,46 +28,58 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["load_parameters", "save_parameters"]
__all__ = ["Checkpoint"]
def _load_latest_checkpoint(checkpoint_dir: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
class Checkpoint():
def __init__(self, kbest_n: int=5, latest_n: int=1):
self.best_records: Mapping[Path, float] = {}
self.latest_records = []
self.kbest_n = kbest_n
self.latest_n = latest_n
self._save_all = (kbest_n == -1)
def _save_record(checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to the checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:{}\n".format(iteration))
def add_checkpoint(self,
checkpoint_dir,
tag_or_iteration: Union[int, Text],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
infos: dict=None,
metric_type="val_loss"):
"""Save checkpoint in best_n and latest_n.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
tag_or_iteration (int or str): the latest iteration (step or epoch) number or tag.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
infos (dict or None): any info you want to save.
metric_type (str, optional): metric type. Defaults to "val_loss".
"""
if (metric_type not in infos.keys()):
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
return
#save best
if self._should_save_best(infos[metric_type]):
self._save_best_checkpoint_and_update(
infos[metric_type], checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
#save latest
self._save_latest_checkpoint_and_update(
checkpoint_dir, tag_or_iteration, model, optimizer, infos)
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(model,
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a specific model checkpoint from disk.
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
@ -73,21 +88,25 @@ def load_parameters(model,
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
record_file "checkpoint_latest" or "checkpoint_best"
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
configs = {}
if checkpoint_path is not None:
tag = os.path.basename(checkpoint_path).split(":")[-1]
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
pass
elif checkpoint_dir is not None and record_file is not None:
# load checkpoint from record file
checkpoint_record = os.path.join(checkpoint_dir, record_file)
iteration = self._load_checkpoint_idx(checkpoint_record)
if iteration == -1:
return configs
checkpoint_path = os.path.join(checkpoint_dir, "{}".format(iteration))
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
"At least one of 'checkpoint_path' or 'checkpoint_dir' should be specified!"
)
rank = dist.get_rank()
@ -110,9 +129,139 @@ def load_parameters(model,
configs = json.load(fin)
return configs
def load_latest_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_latest")
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir: str,
def load_best_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path(prefix) and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
configs (dict): epoch or step, lr and other meta info should be saved.
"""
return self.load_parameters(model, optimizer, checkpoint_dir,
checkpoint_path, "checkpoint_best")
def _should_save_best(self, metric: float) -> bool:
if not self._best_full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
# worst_record_path = max(self.best_records.iteritems(), key=operator.itemgetter(1))[0]
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def _best_full(self):
return (not self._save_all) and len(self.best_records) == self.kbest_n
def _latest_full(self):
return len(self.latest_records) == self.latest_n
def _save_best_checkpoint_and_update(self, metric, checkpoint_dir,
tag_or_iteration, model, optimizer,
infos):
# remove the worst
if self._best_full():
worst_record_path = max(self.best_records,
key=self.best_records.get)
self.best_records.pop(worst_record_path)
if (worst_record_path not in self.latest_records):
logger.info(
"remove the worst checkpoint: {}".format(worst_record_path))
self._del_checkpoint(checkpoint_dir, worst_record_path)
# add the new one
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
self.best_records[tag_or_iteration] = metric
def _save_latest_checkpoint_and_update(
self, checkpoint_dir, tag_or_iteration, model, optimizer, infos):
# remove the old
if self._latest_full():
to_del_fn = self.latest_records.pop(0)
if (to_del_fn not in self.best_records.keys()):
logger.info(
"remove the latest checkpoint: {}".format(to_del_fn))
self._del_checkpoint(checkpoint_dir, to_del_fn)
self.latest_records.append(tag_or_iteration)
self._save_parameters(checkpoint_dir, tag_or_iteration, model,
optimizer, infos)
def _del_checkpoint(self, checkpoint_dir, tag_or_iteration):
checkpoint_path = os.path.join(checkpoint_dir,
"{}".format(tag_or_iteration))
for filename in glob.glob(checkpoint_path + ".*"):
os.remove(filename)
logger.info("delete file: {}".format(filename))
def _load_checkpoint_idx(self, checkpoint_record: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_record (str): path of the checkpoint record file.
Returns:
int: the latest iteration number. -1 for no checkpoint to load.
"""
if not os.path.isfile(checkpoint_record):
return -1
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
iteration = int(latest_checkpoint.split(":")[-1])
return iteration
def _save_checkpoint_record(self, checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpoint record.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record_latest = os.path.join(checkpoint_dir,
"checkpoint_latest")
checkpoint_record_best = os.path.join(checkpoint_dir, "checkpoint_best")
with open(checkpoint_record_best, "w") as handle:
for i in self.best_records.keys():
handle.write("model_checkpoint_path:{}\n".format(i))
with open(checkpoint_record_latest, "w") as handle:
for i in self.latest_records:
handle.write("model_checkpoint_path:{}\n".format(i))
@mp_tools.rank_zero_only
def _save_parameters(self,
checkpoint_dir: str,
tag_or_iteration: Union[int, str],
model: paddle.nn.Layer,
optimizer: Optimizer=None,
@ -147,6 +296,3 @@ def save_parameters(checkpoint_dir: str,
with open(info_path, 'w') as fout:
data = json.dumps(infos)
fout.write(data)
if isinstance(tag_or_iteration, int):
_save_record(checkpoint_dir, tag_or_iteration)
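For orientation, a standalone sketch of the eviction policy implemented by _should_save_best, _save_best_checkpoint_and_update and _save_latest_checkpoint_and_update above: keep the kbest_n checkpoints with the smallest metric plus the latest_n most recent ones, and only delete files once a tag belongs to neither set. Tags and metric values here are hypothetical, and no disk I/O is done.

```python
best_records = {}    # tag -> val_loss, mirrors self.best_records
latest_records = []  # recent tags, oldest first, mirrors self.latest_records
KBEST_N, LATEST_N = 2, 1

def add_checkpoint(tag, val_loss):
    deleted = []
    # k-best bookkeeping: evict the worst record if the new one beats it
    if len(best_records) < KBEST_N:
        best_records[tag] = val_loss
    else:
        worst = max(best_records, key=best_records.get)
        if val_loss < best_records[worst]:
            best_records.pop(worst)
            if worst not in latest_records:
                deleted.append(worst)  # _del_checkpoint would run here
            best_records[tag] = val_loss
    # latest-n bookkeeping: drop the oldest tag unless it is also a best one
    if len(latest_records) == LATEST_N:
        old = latest_records.pop(0)
        if old not in best_records:
            deleted.append(old)
    latest_records.append(tag)
    return deleted

add_checkpoint(1, 0.9)
add_checkpoint(2, 0.8)
print(add_checkpoint(3, 0.7))  # [1]: worst of the best, and no longer latest
```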

@ -38,21 +38,23 @@ def remove_duplicates_and_blank(hyp: List[int], blank_id=0) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
# add non-blank into new_hyp
if hyp[cur] != blank_id:
new_hyp.append(hyp[cur])
# skip repeat label
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def insert_blank(label: np.ndarray, blank_id: int=0):
def insert_blank(label: np.ndarray, blank_id: int=0) -> np.ndarray:
"""Insert blank token between every two label token.
"abcdefg" -> "-a-b-c-d-e-f-g-"
Args:
label ([np.ndarray]): label ids, (L).
label (np.ndarray or List[int]): label ids, (L).
blank_id (int, optional): blank id. Defaults to 0.
Returns:
@ -61,13 +63,13 @@ def insert_blank(label: np.ndarray, blank_id: int=0):
label = np.expand_dims(label, 1) #[L, 1]
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1) #[L, 2]
label = label.reshape(-1) #[2L]
label = np.append(label, label[0]) #[2L + 1]
label = label.reshape(-1) #[2L], -l-l-l
label = np.append(label, label[0]) #[2L + 1], -l-l-l-
return label
def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
blank_id=0) -> list:
blank_id=0) -> List[int]:
"""ctc forced alignment.
https://distill.pub/2017/ctc/
@ -77,23 +79,25 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
y (paddle.Tensor): label id sequence tensor, 1d tensor (L)
blank_id (int): blank symbol index
Returns:
paddle.Tensor: best alignment result, (T).
List[int]: best alignment result, (T).
"""
y_insert_blank = insert_blank(y, blank_id)
y_insert_blank = insert_blank(y, blank_id) #(2L+1)
log_alpha = paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank))) #(T, 2L+1)
log_alpha = log_alpha - float('inf') # log of zero
# TODO(Hui Zhang): zeros not support paddle.int16
state_path = (paddle.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int16) - 1
) # state path
(ctc_probs.size(0), len(y_insert_blank)), dtype=paddle.int32) - 1
) # state path, Tuple((T, 2L+1))
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] # Sb
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] # Snb
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[0, 0] = ctc_probs[0][int(y_insert_blank[0])] # State-b, Sb
log_alpha[0, 1] = ctc_probs[0][int(y_insert_blank[1])] # State-nb, Snb
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
for t in range(1, ctc_probs.size(0)): # T
for s in range(len(y_insert_blank)): # 2L+1
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = paddle.to_tensor(
@ -106,11 +110,13 @@ def forced_align(ctc_probs: paddle.Tensor, y: paddle.Tensor,
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][
y_insert_blank[s]]
# TODO(Hui Zhang): VarBase.__getitem__() not support np.int64
log_alpha[t, s] = paddle.max(candidates) + ctc_probs[t][int(
y_insert_blank[s])]
state_path[t, s] = prev_state[paddle.argmax(candidates)]
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int16)
# TODO(Hui Zhang): zeros not support paddle.int16
state_seq = -1 * paddle.ones((ctc_probs.size(0), 1), dtype=paddle.int32)
candidates = paddle.to_tensor([
log_alpha[-1, len(y_insert_blank) - 1], # Sb
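Toy, self-contained examples of the two pure helpers above; the inputs are invented and the outputs follow from the code as written:

```python
import numpy as np

from deepspeech.utils.ctc_utils import insert_blank
from deepspeech.utils.ctc_utils import remove_duplicates_and_blank

# the -l-l-l- expansion used by forced_align: "abc" -> "-a-b-c-"
print(insert_blank(np.array([1, 2, 3]), blank_id=0))  # [0 1 0 2 0 3 0]

# collapse a frame-level CTC path back to a label sequence
print(remove_duplicates_and_blank([0, 1, 1, 0, 2, 2, 3], blank_id=0))  # [1, 2, 3]
```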

@ -0,0 +1,127 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import List
from typing import Text
import textgrid
def segment_alignment(alignment: List[int], blank_id=0) -> List[List[int]]:
"""segment ctc alignment ids by continuous blank and repeat label.
Args:
alignment (List[int]): ctc alignment id sequence.
e.g. [0, 0, 0, 1, 1, 1, 2, 0, 0, 3]
blank_id (int, optional): blank id. Defaults to 0.
Returns:
List[List[int]]: token align, segmented alignment id sequence.
e.g. [[0, 0, 0, 1, 1, 1], [2], [0, 0, 3]]
"""
# convert the alignment to a Praat-friendly format (Praat is a tool for
# doing phonetics by computer) to help analyze the alignment
align_segs = []
# get frames level duration for each token
start = 0
end = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == blank_id: # blank
end += 1
if end == len(alignment):
align_segs[-1].extend(alignment[start:])
break
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[
end]: # repeat label
end += 1
align_segs.append(alignment[start:end])
start = end
return align_segs
def align_to_tierformat(align_segs: List[List[int]],
subsample: int,
token_dict: Dict[int, Text],
blank_id=0) -> List[Text]:
"""Generate textgrid.Interval format from alignment segmentations.
Args:
align_segs (List[List[int]]): segmented ctc alignment ids.
subsample (int): encoder subsample rate; with 25ms frame_length and 10ms hop_length, each alignment id spans subsample frames.
token_dict (Dict[int, Text]): int -> str map.
Returns:
List[Text]: list of textgrid.Interval text, str(start, end, text).
"""
hop_length = 10 # ms
second_ms = 1000 # ms
frame_per_second = second_ms / hop_length # 25ms frame_length, 10ms hop_length
second_per_frame = 1.0 / frame_per_second
begin = 0
duration = 0
tierformat = []
for idx, tokens in enumerate(align_segs):
token_len = len(tokens)
token = tokens[-1]
# time duration in second
duration = token_len * subsample * second_per_frame
if idx < len(align_segs) - 1:
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
else:
for i in tokens:
if i != blank_id:
token = i
break
print(f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}")
tierformat.append(
f"{begin:.2f} {begin + duration:.2f} {token_dict[token]}\n")
begin = begin + duration
return tierformat
def generate_textgrid(maxtime: float,
intervals: List[Text],
output: Text,
name: Text='ali') -> None:
"""Create alignment textgrid file.
Args:
maxtime (float): audio duration.
intervals (List[Text]): ctc output alignment. e.g. "start-time end-time word" per item.
output (Text): textgrid filepath.
name (Text, optional): tier or layer name. Defaults to 'ali'.
"""
# Download Praat: https://www.fon.hum.uva.nl/praat/
avg_interval = maxtime / (len(intervals) + 1)
print(f"average second/token: {avg_interval}")
margin = 0.0001
tg = textgrid.TextGrid(maxTime=maxtime)
tier = textgrid.IntervalTier(name=name, maxTime=maxtime)
i = 0
for dur in intervals:
s, e, text = dur.split()
tier.add(minTime=float(s) + margin, maxTime=float(e), mark=text)
tg.append(tier)
tg.write(output)
print("successfully generator textgrid {}.".format(output))

@ -79,3 +79,22 @@ def log_add(args: List[int]) -> float:
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
def get_subsample(config):
"""Subsample rate from config.
Args:
config (yacs.config.CfgNode): yaml config
Returns:
int: subsample rate.
"""
input_layer = config["model"]["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
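A minimal sketch of the new get_subsample helper; the in-memory CfgNode below is hypothetical, while real recipes set model.encoder_conf.input_layer in their conf/*.yaml:

```python
from yacs.config import CfgNode

from deepspeech.utils.utility import get_subsample

config = CfgNode({'model': {'encoder_conf': {'input_layer': 'conv2d'}}})
assert get_subsample(config) == 4  # conv2d -> 4x; conv2d6 -> 6x; conv2d8 -> 8x
```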

@ -49,6 +49,9 @@ training:
weight_decay: 1e-06
global_grad_clip: 3.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128

@ -4,10 +4,11 @@
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
| conformer | 47.06M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention | - | 0.059858 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_greedy_search | - | 0.062311 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | ctc_prefix_beam_search | - | 0.062196 |
| conformer | 47.07M | conf/conformer.yaml | spec_aug + shift | test | attention_rescoring | - | 0.054694 |
## Chunk Conformer
@ -21,6 +22,6 @@
## Transformer
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | ---|
| transformer | conf/transformer.yaml | spec_aug + shift | test | attention | - | - |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | ---|
| transformer | - | conf/transformer.yaml | spec_aug + shift | test | attention | - | - |

@ -93,6 +93,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:

@ -88,6 +88,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${ckpt_name}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -30,10 +30,15 @@ fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# test ckpt avg_n
CUDA_VISIBLE_DEVICES=4 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
CUDA_VISIBLE_DEVICES=0 ./local/test.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
CUDA_VISIBLE_DEVICES=0 ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

@ -0,0 +1,4 @@
*.tgz
manifest.*
*.meta
aidatatang_200zh/

@ -0,0 +1,14 @@
# [Aidatatang_200zh](http://www.openslr.org/62/)
Aidatatang_200zh is a free Chinese Mandarin speech corpus provided by Beijing DataTang Technology Co., Ltd under Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 International Public License.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 200 hours of acoustic data, which is mostly mobile recorded data.
* 600 speakers from different accent areas in China are invited to participate in the recording.
* The transcription accuracy for each sentence is larger than 98%.
* Recordings are conducted in a quiet indoor environment.
* The database is divided into training set, validation set, and testing set in a ratio of 7: 1: 2.
* Detail information such as speech data coding and speaker information is preserved in the metadata file.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, voiceprint recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use.

@ -0,0 +1,151 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare aidatatang_200zh mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/62'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/62'
DATA_URL = URL_ROOT + '/aidatatang_200zh.tgz'
MD5_DATA = '6e0f4f39cd5f667a7ee53c397c8d0949'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/aidatatang_200zh",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aidatatang_200_zh_transcript.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '':
continue
audio_id, text = line.split(' ', 1)
# remove whitespace, character text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'corpus/', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
if not fname.endswith('.wav'):
continue
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
with open(dtype + '.meta', 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, subset)
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'corpus')
for subfolder, dirlist, filelist in sorted(os.walk(audio_dir)):
for sub in dirlist:
print(f"unpack dir {sub}...")
for folder, _, filelist in sorted(
os.walk(os.path.join(subfolder, sub))):
for ftar in filelist:
unpack(os.path.join(folder, ftar), folder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix,
subset='aidatatang_200zh')
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
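For reference, one manifest line as create_manifest above would emit it; every field value below is invented:

```python
import json

# hypothetical manifest entry; 'feat_shape' carries the duration in seconds
entry = {
    'utt': 'T0055G0013S0001',
    'feat': '/data/aidatatang_200zh/corpus/train/G0013/T0055G0013S0001.wav',
    'feat_shape': (3.27, ),
    'text': '今天天气怎么样',
}
print(json.dumps(entry, ensure_ascii=False))
```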

@ -1 +1,4 @@
data_aishell*
*.meta
manifest.*
*.tgz

@ -0,0 +1,3 @@
# [Aishell1](http://www.openslr.org/33/)
This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long. It is a part of AISHELL-ASR0009, whose utterances cover 11 domains, including smart home, autonomous driving, and industrial production. All recordings were made in a quiet indoor environment using 3 different devices at the same time: a high-fidelity microphone (44.1kHz, 16-bit), an Android-system mobile phone (16kHz, 16-bit), and an iOS-system mobile phone (16kHz, 16-bit). The high-fidelity audio was re-sampled to 16kHz to build AISHELL-ASR0009-OS1. 400 speakers from different accent areas in China were invited to participate in the recording. The manual transcription accuracy is above 95%, achieved through professional speech annotation and strict quality inspection. The corpus is divided into training, development and testing sets. (This database is free for academic research; commercial use requires permission.)

@ -31,7 +31,7 @@ from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/33'
URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/33'
DATA_URL = URL_ROOT + '/data_aishell.tgz'
MD5_DATA = '2f494334227864a8a8fec932999db9d8'
@ -60,18 +60,22 @@ def create_manifest(data_dir, manifest_path_prefix):
if line == '':
continue
audio_id, text = line.split(' ', 1)
# remove withespace
# remove whitespace, character text
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, 'wav', dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
audio_path = os.path.abspath(os.path.join(subfolder, fname))
audio_id = os.path.basename(fname)[:-4]
# if there is no transcription for the audio, skip it
if audio_id not in transcript_dict:
continue
@ -81,20 +85,30 @@ def create_manifest(data_dir, manifest_path_prefix):
json_lines.append(
json.dumps(
{
'utt':
os.path.splitext(os.path.basename(audio_path))[0],
'feat':
audio_path,
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text':
text
'text': text
},
ensure_ascii=False))
total_sec += duration
total_text += len(text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
with open(dtype + '.meta', 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create manifest file."""
@ -123,6 +137,8 @@ def main():
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()

@ -0,0 +1,3 @@
# [Aishell3](http://www.openslr.org/93/)
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus which could be used to train multi-speaker Text-to-Speech (TTS) systems. The corpus contains roughly **85 hours** of emotion-neutral recordings spoken by 218 native Chinese mandarin speakers, with 88035 utterances in total. Auxiliary attributes such as gender, age group and native accents are explicitly marked and provided in the corpus. Accordingly, transcripts at the Chinese character level and the pinyin level are provided along with the recordings. The word & tone transcription accuracy is above 98%, achieved through professional speech annotation and strict quality inspection for tone and prosody. (This database is free for academic research; commercial use requires permission.)

@ -0,0 +1,10 @@
# [GigaSpeech](https://github.com/SpeechColab/GigaSpeech)
```
git clone https://github.com/SpeechColab/GigaSpeech.git
cd GigaSpeech
utils/gigaspeech_download.sh /disk1/audio_data/gigaspeech
toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
cd ..
```

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,14 @@
#!/bin/bash
set -e
curdir=$PWD
test -d GigaSpeech || git clone https://github.com/SpeechColab/GigaSpeech.git
pushd GigaSpeech
source env_vars.sh
./utils/download_gigaspeech.sh ${curdir}/
#toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL /disk1/audio_data/gigaspeech ../data
popd

@ -5,3 +5,5 @@ test-other
train-clean-100
train-clean-360
train-other-500
*.meta
manifest.*

@ -77,6 +77,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
@ -86,7 +90,9 @@ def create_manifest(data_dir, manifest_path):
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.join(subfolder, segments[0] + '.flac')
audio_filepath = os.path.abspath(
os.path.join(subfolder, segments[0] + '.flac'))
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
@ -99,10 +105,24 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
with open(subset + '.meta', 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.

@ -0,0 +1,15 @@
# [MagicData](http://www.openslr.org/68/)
MAGICDATA Mandarin Chinese Read Speech Corpus was developed by MAGIC DATA Technology Co., Ltd. and freely published for non-commercial use.
The contents and the corresponding descriptions of the corpus include:
* The corpus contains 755 hours of speech data, which is mostly mobile recorded data.
* 1080 speakers from different accent areas in China are invited to participate in the recording.
* The sentence transcription accuracy is higher than 98%.
* Recordings are conducted in a quiet indoor environment.
* The database is divided into training set, validation set, and testing set in a ratio of 51: 1: 2.
* Detail information such as speech data coding and speaker information is preserved in the metadata file.
* The domain of recording texts is diversified, including interactive Q&A, music search, SNS messages, home command and control, etc.
* Segmented transcripts are also provided.
The corpus aims to support researchers in speech recognition, machine translation, speaker recognition, and other speech-related fields. Therefore, the corpus is totally free for academic use.

@ -2,3 +2,4 @@ dev-clean/
manifest.dev-clean
manifest.train-clean
train-clean/
*.meta

@ -58,6 +58,10 @@ def create_manifest(data_dir, manifest_path):
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
total_sec = 0.0
total_text = 0.0
total_num = 0
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
@ -80,10 +84,24 @@ def create_manifest(data_dir, manifest_path):
'text':
text
}))
total_sec += duration
total_text += len(text)
total_num += 1
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
subset = os.path.splitext(manifest_path)[1][1:]
with open(subset + '.meta', 'w') as f:
print(f"{subset}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.

@ -0,0 +1,11 @@
# multi-cn
This is a Chinese speech recognition recipe that trains on all Chinese corpora on OpenSLR, including:
* Aidatatang (140 hours)
* Aishell (151 hours)
* MagicData (712 hours)
* Primewords (99 hours)
* ST-CMDS (110 hours)
* THCHS-30 (26 hours)
* optional AISHELL2 (~1000 hours) if available

@ -0,0 +1,6 @@
# [Primewords](http://www.openslr.org/47/)
This free Chinese Mandarin speech corpus set is released by Shanghai Primewords Information Technology Co., Ltd.
The corpus is recorded by smart mobile phones from 296 native Chinese speakers. The transcription accuracy is larger than 98%, at the confidence level of 95%. It is free for academic use.
The mapping between the transcript and utterance is given in JSON format.

@ -0,0 +1 @@
# [FreeST](http://www.openslr.org/38/)

@ -0,0 +1,6 @@
*.tgz
manifest.*
data_thchs30
resource
test-noise
*.meta

@ -0,0 +1,55 @@
# [THCHS30](http://www.openslr.org/18/)
This is the *data part* of the `THCHS30 2015` acoustic data
& scripts dataset.
The dataset is described in more detail in the paper ``THCHS-30 : A Free
Chinese Speech Corpus`` by Dong Wang, Xuewei Zhang.
A paper (if it can be called a paper) 13 years ago regarding the database:
Dong Wang, Dalei Wu, Xiaoyan Zhu, ``TCMSD: A new Chinese Continuous Speech Database``,
International Conference on Chinese Computing (ICCC'01), 2001, Singapore.
The layout of this data pack is the following:
``data``
``*.wav``
audio data
``*.wav.trn``
transcriptions
``{train,dev,test}``
contain symlinks into the ``data`` directory for both audio and
transcription files. Contents of these directories define the
train/dev/test split of the data.
``{lm_word}``
``word.3gram.lm``
trigram LM based on word
``lexicon.txt``
lexicon based on word
``{lm_phone}``
``phone.3gram.lm``
trigram LM based on phone
``lexicon.txt``
lexicon based on phone
``README.TXT``
this file
Data statistics
===============
Statistics for the data are as follows:
=========== ========== ========== ===========
**dataset** **audio** **#sents** **#words**
=========== ========== ========== ===========
train 25 10,000 198,252
dev 2:14 893 17,743
test 6:15 2,495 49,085
=========== ========== ========== ===========

@ -0,0 +1,184 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare THCHS-30 mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
from multiprocessing.pool import Pool
from pathlib import Path
import soundfile
from utils.utility import download
from utils.utility import unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/18'
# URL_ROOT = 'https://openslr.magicdatatech.com/resources/18'
DATA_URL = URL_ROOT + '/data_thchs30.tgz'
TEST_NOISE_URL = URL_ROOT + '/test-noise.tgz'
RESOURCE_URL = URL_ROOT + '/resource.tgz'
MD5_DATA = '2d2252bde5c8429929e1841d4cb95e90'
MD5_TEST_NOISE = '7e8a985fb965b84141b68c68556c2030'
MD5_RESOURCE = 'c0b2a565b4970a0c4fe89fefbf2d97e1'
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/THCHS30",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def read_trn(filepath):
"""read trn file.
word text in first line.
syllable text in second line.
phoneme text in third line.
Args:
filepath (str): trn path.
Returns:
list(str): (word, syllable, phone)
"""
texts = []
with open(filepath, 'r') as f:
lines = f.read().strip().split('\n')
assert len(lines) == 3, lines
# charactor text, remove withespace
texts.append(''.join(lines[0].split()))
texts.extend(lines[1:])
return texts
def resolve_symlink(filepath):
"""resolve symlink which content is norm file.
Args:
filepath (str): norm file symlink.
"""
sym_path = Path(filepath)
relative_link = sym_path.read_text().strip()
relative = Path(relative_link)
relpath = sym_path.parent / relative
return relpath.resolve()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
data_types = ['train', 'dev', 'test']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = os.path.join(data_dir, dtype)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
file_path = os.path.join(subfolder, fname)
if file_path.endswith('.wav'):
audio_path = os.path.abspath(file_path)
text_path = resolve_symlink(audio_path + '.trn')
else:
continue
assert os.path.exists(audio_path) and os.path.exists(text_path)
audio_id = os.path.basename(audio_path)[:-4]
word_text, syllable_text, phone_text = read_trn(text_path)
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
# not dump alignment infos
json_lines.append(
json.dumps(
{
'utt': audio_id,
'feat': audio_path,
'feat_shape': (duration, ), # second
'text': word_text, # charactor
'syllable': syllable_text,
'phone': phone_text,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text)
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
with open(dtype + '.meta', 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path, subset):
"""Download, unpack and create manifest file."""
datadir = os.path.join(target_dir, subset)
if not os.path.exists(datadir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
if subset == 'data_thchs30':
create_manifest(datadir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
tasks = [
(DATA_URL, MD5_DATA, args.target_dir, args.manifest_prefix,
"data_thchs30"),
(TEST_NOISE_URL, MD5_TEST_NOISE, args.target_dir, args.manifest_prefix,
"test-noise"),
(RESOURCE_URL, MD5_RESOURCE, args.target_dir, args.manifest_prefix,
"resource"),
]
with Pool(7) as pool:
pool.starmap(prepare_dataset, tasks)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
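To make the read_trn contract above concrete, a sketch over a hypothetical *.wav.trn payload (word, syllable and phone lines):

```python
# hypothetical contents of a THCHS-30 *.wav.trn file
raw = "绿 是 阳春\nlv4 shi4 yang2 chun1\nl v4 sh ix4 ii ang2 ch un1"
lines = raw.strip().split('\n')

# same normalization as read_trn: join the character line, keep the rest
texts = [''.join(lines[0].split())] + lines[1:]
assert texts[0] == '绿是阳春'
```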

@ -0,0 +1,4 @@
TIMIT.*
TIMIT
manifest.*
*.meta

@ -0,0 +1,239 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Librispeech ASR datasets.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import argparse
import codecs
import json
import os
import re
import string
from pathlib import Path
import soundfile
from utils.utility import check_md5sum
from utils.utility import unzip
URL_ROOT = ""
MD5_DATA = "45c68037c7fdfe063a43c851f181fb2d"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default='~/.cache/paddle/dataset/speech/timit',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
#: A string containing Chinese punctuation marks (non-stops).
non_stops = (
# Fullwidth ASCII variants
'\uFF02\uFF03\uFF04\uFF05\uFF06\uFF07\uFF08\uFF09\uFF0A\uFF0B\uFF0C\uFF0D'
'\uFF0F\uFF1A\uFF1B\uFF1C\uFF1D\uFF1E\uFF20\uFF3B\uFF3C\uFF3D\uFF3E\uFF3F'
'\uFF40\uFF5B\uFF5C\uFF5D\uFF5E\uFF5F\uFF60'
# Halfwidth CJK punctuation
'\uFF62\uFF63\uFF64'
# CJK symbols and punctuation
'\u3000\u3001\u3003'
# CJK angle and corner brackets
'\u3008\u3009\u300A\u300B\u300C\u300D\u300E\u300F\u3010\u3011'
# CJK brackets and symbols/punctuation
'\u3014\u3015\u3016\u3017\u3018\u3019\u301A\u301B\u301C\u301D\u301E\u301F'
# Other CJK symbols
'\u3030'
# Special CJK indicators
'\u303E\u303F'
# Dashes
'\u2013\u2014'
# Quotation marks and apostrophe
'\u2018\u2019\u201B\u201C\u201D\u201E\u201F'
# General punctuation
'\u2026\u2027'
# Overscores and underscores
'\uFE4F'
# Small form variants
'\uFE51\uFE54'
# Latin punctuation
'\u00B7')
#: A string of Chinese stops.
stops = (
'\uFF01' # Fullwidth exclamation mark
'\uFF1F' # Fullwidth question mark
'\uFF61' # Halfwidth ideographic full stop
'\u3002' # Ideographic full stop
)
#: A string containing all Chinese punctuation.
punctuation = non_stops + stops
def tn(text):
# lower text
text = text.lower()
# remove punc
text = re.sub(f'[{punctuation}{string.punctuation}]', "", text)
return text
def read_txt(filepath: str) -> str:
with open(filepath, 'r') as f:
line = f.read().strip().split(maxsplit=2)[2]
return tn(line)
def read_algin(filepath: str) -> str:
"""read word or phone alignment file.
<start-sample> <end-sample> <token><newline>
Args:
filepath (str): path of the alignment file.
Returns:
str: tokens separated by <space>
"""
aligns = [] # (start, end, token)
with open(filepath, 'r') as f:
for line in f:
items = line.strip().split()
# for phone: (Note: beginning and ending silence regions are marked with h#)
if items[2].strip() == 'h#':
continue
aligns.append(items)
return ' '.join([item[2] for item in aligns])
def create_manifest(data_dir, manifest_path_prefix):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
utts = set()
data_types = ['TRAIN', 'TEST']
for dtype in data_types:
del json_lines[:]
total_sec = 0.0
total_text = 0.0
total_num = 0
audio_dir = Path(os.path.join(data_dir, dtype))
for fname in sorted(audio_dir.rglob('*.WAV')):
audio_path = fname.resolve() # .WAV
audio_id = audio_path.stem
# if the utt id already exists, skip it
if audio_id in utts:
continue
utts.add(audio_id)
text_path = audio_path.with_suffix('.TXT')
phone_path = audio_path.with_suffix('.PHN')
word_path = audio_path.with_suffix('.WRD')
audio_data, samplerate = soundfile.read(
str(audio_path), dtype='int16')
duration = float(len(audio_data) / samplerate)
word_text = read_txt(text_path)
phone_text = read_algin(phone_path)
gender_spk = str(audio_path.parent.stem)
spk = gender_spk[1:]
gender = gender_spk[0]
utt_id = '_'.join([spk, gender, audio_id])
# not dump alignment infos
json_lines.append(
json.dumps(
{
'utt': utt_id,
'feat': str(audio_path),
'feat_shape': (duration, ), # second
'text': word_text, # word
'phone': phone_text,
'spk': spk,
'gender': gender,
},
ensure_ascii=False))
total_sec += duration
total_text += len(word_text.split())
total_num += 1
manifest_path = manifest_path_prefix + '.' + dtype.lower()
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
with open(dtype.lower() + '.meta', 'w') as f:
print(f"{dtype}:", file=f)
print(f"{total_num} utts", file=f)
print(f"{total_sec / (60*60)} h", file=f)
print(f"{total_text} text", file=f)
print(f"{total_text / total_sec} text/sec", file=f)
print(f"{total_sec / total_num} sec/utt", file=f)
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
filepath = os.path.join(target_dir, "TIMIT.zip")
if not os.path.exists(filepath):
print(f"Please download TIMIT.zip into {target_dir}.")
raise FileNotFoundError
if not os.path.exists(os.path.join(target_dir, "TIMIT")):
# check md5sum
assert check_md5sum(filepath, md5sum)
# unpack
unzip(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(os.path.join(target_dir, "TIMIT"), manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(URL_ROOT, MD5_DATA, args.target_dir, args.manifest_prefix)
print("Data download and manifest prepare done!")
if __name__ == '__main__':
main()
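To illustrate the read_algin format above, a sketch over hypothetical *.PHN lines; the sample ranges are invented, and the h# silence markers are dropped exactly as read_algin does:

```python
# hypothetical *.PHN lines: <start-sample> <end-sample> <token>
phn_lines = ["0 3050 h#", "3050 4559 sh", "4559 5723 ix", "5723 6100 h#"]
tokens = ' '.join(l.split()[2] for l in phn_lines if l.split()[2] != 'h#')
assert tokens == 'sh ix'
```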

@ -2,7 +2,7 @@
## Deepspeech2
| Model | Params | Release | Config | Test set | Loss | WER |
| Model | Params | release | Config | Test set | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| DeepSpeech2 | 42.96M | 2.1.0 | conf/deepspeech2.yaml | test-clean | 15.184467315673828 | 0.072154 |
| DeepSpeech2 | 42.96M | 2.0.0 | conf/deepspeech2.yaml | test-clean | - | 0.073973 |

@ -48,6 +48,9 @@ training:
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
batch_size: 128

@ -2,17 +2,17 @@
## Conformer
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| conformer | conf/conformer.yaml | spec_aug + shift | test-all | attention | test-all 6.35 | 0.057117 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.35 | 0.030162 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | test-all 6.35 | 0.037910 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | test-all 6.35 | 0.037761 |
| conformer | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | test-all 6.35 | 0.032115 |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-all | attention | 6.35 | 0.057117 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention | 6.35 | 0.030162 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_greedy_search | 6.35 | 0.037910 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | ctc_prefix_beam_search | 6.35 | 0.037761 |
| conformer | 47.63 M | conf/conformer.yaml | spec_aug + shift | test-clean | attention_rescoring | 6.35 | 0.032115 |
## Transformer
| Model | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- |
| transformer | conf/transformer.yaml | spec_aug + shift | test-all | attention | test-all 6.98 | 0.066500 |
| transformer | conf/transformer.yaml | spec_aug + shift | test-clean | attention | test-all 6.98 | 0.036 |
| Model | Params | Config | Augmentation| Test set | Decode method | Loss | WER |
| --- | --- | --- | --- | --- | --- | --- | --- |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-all | attention | 6.98 | 0.066500 |
| transformer | 32.52 M | conf/transformer.yaml | spec_aug + shift | test-clean | attention | 6.98 | 0.036 |

@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 4
min_input_len: 0.5
max_input_len: 20.0
min_output_len: 0.0
max_output_len: 400.0
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
@ -80,7 +82,7 @@ model:
training:
n_epoch: 120
accum_grad: 1
accum_grad: 8
global_grad_clip: 5.0
optim: adam
optim_conf:
@ -91,6 +93,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:

@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
@ -84,6 +86,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:
@ -103,6 +108,6 @@ decoding:
# >0: for decoding, use fixed chunk size as set.
# 0: used for training only; prohibited for decoding.
num_decoding_left_chunks: -1 # number of left chunks for decoding. Defaults to -1.
simulate_streaming: False # simulate streaming inference. Defaults to False.
simulate_streaming: true # simulate streaming inference. Defaults to False.

@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
min_input_len: 0.5 # seconds
max_input_len: 20.0 # seconds
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 16
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
@ -87,6 +89,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:

@ -3,18 +3,20 @@ data:
train_manifest: data/manifest.train
dev_manifest: data/manifest.dev
test_manifest: data/manifest.test-clean
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
min_input_len: 0.5 # second
max_input_len: 20.0 # second
min_output_len: 0.0 # tokens
max_output_len: 400.0 # tokens
min_output_input_ratio: 0.05
max_output_input_ratio: 10.0
collator:
vocab_filepath: data/vocab.txt
unit_type: 'spm'
spm_model_prefix: 'data/bpe_unigram_5000'
mean_std_filepath: ""
augmentation_config: conf/augmentation.json
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
@ -82,6 +84,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
decoding:

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${ckpt_name}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -33,6 +33,11 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi
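Assuming these recipes parse --stage/--stop_stage via utils/parse_options.sh as elsewhere in the repo (an assumption; the parsing is not shown in this hunk), the new alignment stage can be run in isolation:
# hypothetical invocation; stage number taken from the hunk above
bash run.sh --stage 4 --stop_stage 4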

@ -1,4 +1,4 @@
export MAIN_ROOT=${PWD}/../../
export MAIN_ROOT=${PWD}/../../../
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

@ -43,12 +43,16 @@ model:
share_rnn_weights: True
training:
n_epoch: 24
n_epoch: 10
lr: 1e-5
lr_decay: 1.0
weight_decay: 1e-06
global_grad_clip: 5.0
log_interval: 1
checkpoint:
kbest_n: 3
latest_n: 2
decoding:
batch_size: 128

@ -91,6 +91,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:

@ -84,6 +84,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:

@ -87,6 +87,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:

@ -84,6 +84,9 @@ training:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 1
checkpoint:
kbest_n: 10
latest_n: 1
decoding:

@ -0,0 +1,43 @@
#! /usr/bin/env bash
if [ $# != 2 ];then
echo "usage: ${0} config_path ckpt_path_prefix"
exit -1
fi
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
device=gpu
if [ ${ngpu} == 0 ];then
device=cpu
fi
config_path=$1
ckpt_prefix=$2
ckpt_name=$(basename ${ckpt_prefix})
mkdir -p exp
batch_size=1
output_dir=${ckpt_prefix}
mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
--device ${device} \
--nproc 1 \
--config ${config_path} \
--result_file ${output_dir}/${ckpt_name}.align \
--checkpoint_path ${ckpt_prefix} \
--opts decoding.batch_size ${batch_size}
if [ $? -ne 0 ]; then
echo "Failed in ctc alignment!"
exit 1
fi
exit 0

@ -34,6 +34,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# ctc alignment of test data
CUDA_VISIBLE_DEVICES=0 ./local/align.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} || exit -1
fi
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# export ckpt avg_n
CUDA_VISIBLE_DEVICES= ./local/export.sh ${conf_path} exp/${ckpt}/checkpoints/${avg_ckpt} exp/${ckpt}/checkpoints/${avg_ckpt}.jit
fi

@ -0,0 +1,77 @@
cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(deepspeech VERSION 0.1)
set(CMAKE_VERBOSE_MAKEFILE on)
# set std-14
set(CMAKE_CXX_STANDARD 14)
# include file
include(FetchContent)
include(ExternalProject)
# fc_patch dir
set(FETCHCONTENT_QUIET off)
get_filename_component(fc_patch "fc_patch" REALPATH BASE_DIR "${CMAKE_SOURCE_DIR}")
set(FETCHCONTENT_BASE_DIR ${fc_patch})
###############################################################################
# Option Configurations
###############################################################################
# option configurations
option(TEST_DEBUG "option for debug" OFF)
###############################################################################
# Include third party
###############################################################################
# # example of including a third-party library:
# FetchContent_Declare()
# # FetchContent_MakeAvailable was not added until CMake 3.14
# FetchContent_MakeAvailable()
# include_directories()
# ABSEIL-CPP
include(FetchContent)
FetchContent_Declare(
absl
GIT_REPOSITORY "https://github.com/abseil/abseil-cpp.git"
GIT_TAG "20210324.1"
)
FetchContent_MakeAvailable(absl)
# libsndfile
include(FetchContent)
FetchContent_Declare(
libsndfile
GIT_REPOSITORY "https://github.com/libsndfile/libsndfile.git"
GIT_TAG "1.0.31"
)
FetchContent_MakeAvailable(libsndfile)
###############################################################################
# Add local library
###############################################################################
# system lib
# find_package(<PackageName> REQUIRED)
# if the dir has its own CMakeLists.txt
# add_subdirectory(<dir>)
# if the dir does not have a CMakeLists.txt
# add_library(lib_name STATIC file.cc)
# target_link_libraries(lib_name item0 item1)
# add_dependencies(lib_name depend-target)
###############################################################################
# Library installation
###############################################################################
# install(TARGETS <target> DESTINATION <dir>)
###############################################################################
# Build binary file
###############################################################################
# add_executable(<name> <source files>)
# target_link_libraries(<name> <libraries>)
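A minimal concrete filling of the commented template above; file and target names are hypothetical, and the linked targets assume the FetchContent declarations earlier in this file:
# hypothetical example; decoder_main.cc and target names are placeholders
add_subdirectory(decoder)                      # dir with its own CMakeLists.txt
add_executable(decoder_main decoder_main.cc)   # build a binary
target_link_libraries(decoder_main decoder absl::strings sndfile)
install(TARGETS decoder_main RUNTIME DESTINATION bin)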

@ -0,0 +1,2 @@
aux_source_directory(. DIR_LIB_SRCS)
add_library(decoder STATIC ${DIR_LIB_SRCS})

@ -19,7 +19,7 @@ kenlm.done:
apt-get install -y gcc-5 g++-5 && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-5 50 && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-5 50
test -d kenlm || wget -O - https://kheafield.com/code/kenlm.tar.gz | tar xz
mkdir -p kenlm/build && cd kenlm/build && cmake .. && make -j4 && make install
cd kenlm && python setup.py install
source venv/bin/activate; cd kenlm && python setup.py install
touch kenlm.done
sox.done:

@ -14,9 +14,15 @@
import os
import tarfile
import zipfile
from typing import Text
from paddle.dataset.common import md5file
__all__ = [
"check_md5sum", "getfile_insensitive", "download_multi", "download",
"unpack", "unzip"
]
def getfile_insensitive(path):
"""Get the actual file path when given insensitive filename."""
@ -54,6 +60,19 @@ def download(url, md5sum, target_dir):
return filepath
def check_md5sum(filepath: Text, md5sum: Text) -> bool:
"""Check the md5sum of a file against an expected digest.
Args:
filepath (Text): path of the file to check.
md5sum (Text): expected md5 hex digest.
Returns:
bool: True if the digests match.
"""
return md5file(filepath) == md5sum
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
