refactor repo

fix decoding
pull/538/head
Hui Zhang 5 years ago
parent 49d55a865c
commit 45f73c507c

3
.gitignore vendored

@ -1,5 +1,4 @@
.DS_Store .DS_Store
*.pyc *.pyc
tools/venv tools/venv
dataset .vscode
models/*

@ -20,12 +20,13 @@ import functools
from paddle import distributed as dist from paddle import distributed as dist
from utils.utility import print_arguments from deepspeech.training.cli import default_argument_parser
from training.cli import default_argument_parser from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from model_utils.config import get_cfg_defaults # TODO(hui zhang): dynamic load
from model_utils.model import DeepSpeech2Tester as Tester from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from utils.error_rate import char_errors, word_errors from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
def main_sp(config, args): def main_sp(config, args):

@ -20,12 +20,12 @@ import functools
from paddle import distributed as dist from paddle import distributed as dist
from utils.utility import print_arguments from deepspeech.training.cli import default_argument_parser
from training.cli import default_argument_parser from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from model_utils.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from model_utils.model import DeepSpeech2Tester as Tester from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
from utils.error_rate import char_errors, word_errors
def main_sp(config, args): def main_sp(config, args):

@ -20,11 +20,11 @@ import functools
from paddle import distributed as dist from paddle import distributed as dist
from utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from model_utils.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from model_utils.model import DeepSpeech2Trainer as Trainer from deepspeech.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
def main_sp(config, args): def main_sp(config, args):

@ -20,22 +20,21 @@ import argparse
import functools import functools
import gzip import gzip
import logging import logging
import paddle.fluid as fluid
from training.cli import default_argument_parser
from model_utils.config import get_cfg_defaults
from data_utils.dataset import SpeechCollator
from data_utils.dataset import DeepSpeech2Dataset
from data_utils.dataset import DeepSpeech2DistributedBatchSampler
from data_utils.dataset import DeepSpeech2BatchSampler
from paddle.io import DataLoader from paddle.io import DataLoader
from model_utils.network import DeepSpeech2 from deepspeech.training.cli import default_argument_parser
from model_utils.network import DeepSpeech2Loss from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.network import DeepSpeech2
from deepspeech.models.network import DeepSpeech2Loss
from utils.error_rate import char_errors, word_errors from deepspeech.exps.deepspeech2.dataset import SpeechCollator
from utils.utility import add_arguments, print_arguments from deepspeech.exps.deepspeech2.dataset import DeepSpeech2Dataset
from deepspeech.exps.deepspeech2.dataset import DeepSpeech2DistributedBatchSampler
from deepspeech.exps.deepspeech2.dataset import DeepSpeech2BatchSampler
from deepspeech.exps.deepspeech2.config import get_cfg_defaults
def tune(config, args): def tune(config, args):
@ -114,7 +113,7 @@ def tune(config, args):
return trans return trans
audio, text, audio_len, text_len = infer_data audio, text, audio_len, text_len = infer_data
_, probs, _ = model.predict(audio, audio_len) _, probs, logits_lens = model.predict(audio, audio_len)
target_transcripts = ordid2token(text, text_len) target_transcripts = ordid2token(text, text_len)
num_ins += audio.shape[0] num_ins += audio.shape[0]
@ -122,17 +121,17 @@ def tune(config, args):
for index, (alpha, beta) in enumerate(params_grid): for index, (alpha, beta) in enumerate(params_grid):
print(f"tuneing: alpha={alpha} beta={beta}") print(f"tuneing: alpha={alpha} beta={beta}")
result_transcripts = model.decode_probs( result_transcripts = model.decode_probs(
probs.numpy(), vocab_list, config.decoding.decoding_method, probs.numpy(), logits_lens, vocab_list,
config.decoding.decoding_method,
config.decoding.lang_model_path, alpha, beta, config.decoding.lang_model_path, alpha, beta,
config.decoding.beam_size, config.decoding.cutoff_prob, config.decoding.beam_size, config.decoding.cutoff_prob,
config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch) config.decoding.cutoff_top_n, config.decoding.num_proc_bsearch)
for target, result in zip(target_transcripts, result_transcripts): for target, result in zip(target_transcripts, result_transcripts):
#print(f"tuneing: {target} {result}")
errors, len_ref = errors_func(target, result) errors, len_ref = errors_func(target, result)
err_sum[index] += errors err_sum[index] += errors
# accumulate the length of references of every batch # accumulate the length of references of every batchπ
# in the first iteration # in the first iteration
if args.alpha_from == alpha and args.beta_from == beta: if args.alpha_from == alpha and args.beta_from == beta:
len_refs += len_ref len_refs += len_ref
@ -148,8 +147,9 @@ def tune(config, args):
min_index = err_ave.index(err_ave_min) min_index = err_ave.index(err_ave_min)
print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), " print("\nBatch %d [%d/?], current opt (alpha, beta) = (%s, %s), "
" min [%s] = %f" % " min [%s] = %f" %
(cur_batch, num_ins, "%.3f" % params_grid[min_index][0], "%.3f" % (cur_batch, num_ins, "%.3f" % params_grid[min_index][0],
params_grid[min_index][1], args.error_rate_type, err_ave_min)) "%.3f" % params_grid[min_index][1],
config.decoding.error_rate_type, err_ave_min))
cur_batch += 1 cur_batch += 1
# output WER/CER at every (alpha, beta) # output WER/CER at every (alpha, beta)

@ -56,10 +56,6 @@ _C.training = CN(
lr_decay=1.0, # learning rate decay lr_decay=1.0, # learning rate decay
weight_decay=1e-6, # the coeff of weight decay weight_decay=1e-6, # the coeff of weight decay
global_grad_clip=5.0, # the global norm clip global_grad_clip=5.0, # the global norm clip
plot_interval=1000, # plot attention and spectrogram by step
valid_interval=1000, # validation by step
save_interval=1000, # checkpoint by step
max_iteration=500000, # max iteration to train by step
n_epoch=50, # train epochs n_epoch=50, # train epochs
)) ))

@ -27,11 +27,11 @@ from paddle.io import BatchSampler
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from paddle import distributed as dist from paddle import distributed as dist
from data_utils.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from data_utils.augmentor.augmentation import AugmentationPipeline from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer from deepspeech.frontend.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.speech import SpeechSegment from deepspeech.frontend.speech import SpeechSegment
from data_utils.normalizer import FeatureNormalizer from deepspeech.frontend.normalizer import FeatureNormalizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

@ -29,26 +29,23 @@ from paddle.io import DataLoader
from paddle.fluid.dygraph import base as imperative_base from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import layers from paddle.fluid import layers
from paddle.fluid import framework
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import name_scope
from utils import mp_tools from deepspeech.training import Trainer
from training import Trainer from deepspeech.utils import mp_tools
from deepspeech.utils.error_rate import char_errors, word_errors, cer, wer
from model_utils.network import DeepSpeech2 from deepspeech.models.network import DeepSpeech2
from model_utils.network import DeepSpeech2Loss from deepspeech.models.network import DeepSpeech2Loss
from data_utils.dataset import SpeechCollator from deepspeech.decoders.swig_wrapper import Scorer
from data_utils.dataset import DeepSpeech2Dataset from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
from data_utils.dataset import DeepSpeech2DistributedBatchSampler from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
from data_utils.dataset import DeepSpeech2BatchSampler
from decoders.swig_wrapper import Scorer from deepspeech.exps.deepspeech2.dataset import SpeechCollator
from decoders.swig_wrapper import ctc_greedy_decoder from deepspeech.exps.deepspeech2.dataset import DeepSpeech2Dataset
from decoders.swig_wrapper import ctc_beam_search_decoder_batch from deepspeech.exps.deepspeech2.dataset import DeepSpeech2DistributedBatchSampler
from deepspeech.exps.deepspeech2.dataset import DeepSpeech2BatchSampler
from utils.error_rate import char_errors, word_errors, cer, wer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -161,46 +158,6 @@ class DeepSpeech2Trainer(Trainer):
self.visualizer.add_scalar("train/{}".format(k), v, self.visualizer.add_scalar("train/{}".format(k), v,
self.iteration) self.iteration)
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
"""
if self.parallel:
# batch sampler epoch start from 0
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.epoch += 1
def train(self):
"""The training process.
It includes forward/backward/update and periodical validation and
saving.
"""
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
self.new_epoch()
while self.epoch <= self.config.training.n_epoch:
try:
for batch in self.train_loader:
self.iteration += 1
self.train_batch(batch)
# if self.iteration % self.config.training.valid_interval == 0:
# self.valid()
# if self.iteration % self.config.training.save_interval == 0:
# self.save()
except Exception as e:
self.logger.error(e)
pass
self.valid()
self.save()
self.lr_scheduler.step()
self.new_epoch()
def compute_metrics(self, inputs, outputs):
pass
@mp_tools.rank_zero_only @mp_tools.rank_zero_only
@paddle.no_grad() @paddle.no_grad()
def valid(self): def valid(self):
@ -212,7 +169,7 @@ class DeepSpeech2Trainer(Trainer):
audio, text, audio_len, text_len = batch audio, text, audio_len, text_len = batch
outputs = self.model(*batch) outputs = self.model(*batch)
loss = self.compute_losses(batch, outputs) loss = self.compute_losses(batch, outputs)
metrics = self.compute_metrics(batch, outputs) #metrics = self.compute_metrics(batch, outputs)
valid_losses['val_loss'].append(float(loss)) valid_losses['val_loss'].append(float(loss))
valid_losses['val_loss_div_batchsize'].append( valid_losses['val_loss_div_batchsize'].append(
@ -373,6 +330,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
target_transcripts = self.ordid2token(texts, texts_len) target_transcripts = self.ordid2token(texts, texts_len)
result_transcripts = self.model.decode_probs( result_transcripts = self.model.decode_probs(
probs.numpy(), probs.numpy(),
logits_len,
vocab_list, vocab_list,
decoding_method=cfg.decoding_method, decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path, lang_model_path=cfg.lang_model_path,
@ -446,15 +404,37 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
""" """
# output dir # output dir
if self.args.output: if self.args.output:
output_dir = Path(self.args.output).expanduser() / "infer" output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
else: else:
output_dir = Path( output_dir = Path(
self.args.checkpoint_path).expanduser().parent.parent / "infer" self.args.checkpoint_path).expanduser().parent.parent
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir self.output_dir = output_dir
def setup_logger(self):
"""Initialize a text logger to log the experiment.
Each process has its own text logger. The logging message is write to
the standard output and a text file named ``worker_n.log`` in the
output directory, where ``n`` means the rank of the process.
"""
format = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')
logger.setLevel("INFO")
# global logger
stdout = True
save_path = ""
logging.basicConfig(
level=logging.DEBUG if stdout else logging.INFO,
format=format,
datefmt='%Y/%m/%d %H:%M:%S',
filename=save_path if not stdout else None)
self.logger = logger
def setup(self): def setup(self):
"""Setup the experiment. """Setup the experiment.
""" """
@ -463,6 +443,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
self.init_parallel() self.init_parallel()
self.setup_output_dir() self.setup_output_dir()
self.setup_checkpointer()
self.setup_logger() self.setup_logger()
self.setup_dataloader() self.setup_dataloader()

@ -15,13 +15,13 @@
import json import json
import random import random
from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor from deepspeech.frontend.augmentor.volume_perturb import VolumePerturbAugmentor
from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor from deepspeech.frontend.augmentor.shift_perturb import ShiftPerturbAugmentor
from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor from deepspeech.frontend.augmentor.speed_perturb import SpeedPerturbAugmentor
from data_utils.augmentor.noise_perturb import NoisePerturbAugmentor from deepspeech.frontend.augmentor.noise_perturb import NoisePerturbAugmentor
from data_utils.augmentor.impulse_response import ImpulseResponseAugmentor from deepspeech.frontend.augmentor.impulse_response import ImpulseResponseAugmentor
from data_utils.augmentor.resample import ResampleAugmentor from deepspeech.frontend.augmentor.resample import ResampleAugmentor
from data_utils.augmentor.online_bayesian_normalization import \ from deepspeech.frontend.augmentor.online_bayesian_normalization import \
OnlineBayesianNormalizationAugmentor OnlineBayesianNormalizationAugmentor

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
"""Contains the impulse response augmentation model.""" """Contains the impulse response augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from data_utils.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
class ImpulseResponseAugmentor(AugmentorBase): class ImpulseResponseAugmentor(AugmentorBase):

@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
"""Contains the noise perturb augmentation model.""" """Contains the noise perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
from data_utils.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from data_utils.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
class NoisePerturbAugmentor(AugmentorBase): class NoisePerturbAugmentor(AugmentorBase):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contain the online bayesian normalization augmentation model.""" """Contain the online bayesian normalization augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
class OnlineBayesianNormalizationAugmentor(AugmentorBase): class OnlineBayesianNormalizationAugmentor(AugmentorBase):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contain the resample augmentation model.""" """Contain the resample augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
class ResampleAugmentor(AugmentorBase): class ResampleAugmentor(AugmentorBase):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the volume perturb augmentation model.""" """Contains the volume perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
class ShiftPerturbAugmentor(AugmentorBase): class ShiftPerturbAugmentor(AugmentorBase):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contain the speech perturbation augmentation model.""" """Contain the speech perturbation augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
class SpeedPerturbAugmentor(AugmentorBase): class SpeedPerturbAugmentor(AugmentorBase):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains the volume perturb augmentation model.""" """Contains the volume perturb augmentation model."""
from data_utils.augmentor.base import AugmentorBase from deepspeech.frontend.augmentor.base import AugmentorBase
class VolumePerturbAugmentor(AugmentorBase): class VolumePerturbAugmentor(AugmentorBase):

@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from training.trainer import *

@ -14,8 +14,8 @@
"""Contains the audio featurizer class.""" """Contains the audio featurizer class."""
import numpy as np import numpy as np
from data_utils.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from data_utils.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
from python_speech_features import mfcc from python_speech_features import mfcc
from python_speech_features import delta from python_speech_features import delta

@ -13,8 +13,8 @@
# limitations under the License. # limitations under the License.
"""Contains the speech featurizer class.""" """Contains the speech featurizer class."""
from data_utils.featurizer.audio_featurizer import AudioFeaturizer from deepspeech.frontend.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer from deepspeech.frontend.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object): class SpeechFeaturizer(object):

@ -15,8 +15,8 @@
import numpy as np import numpy as np
import random import random
from data_utils.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from data_utils.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
class FeatureNormalizer(object): class FeatureNormalizer(object):

@ -14,28 +14,33 @@
"""Contains the speech segment class.""" """Contains the speech segment class."""
import numpy as np import numpy as np
from data_utils.audio import AudioSegment from deepspeech.frontend.audio import AudioSegment
class SpeechSegment(AudioSegment): class SpeechSegment(AudioSegment):
"""Speech segment abstraction, a subclass of AudioSegment, """Speech Segment with Text
with an additional transcript.
:param samples: Audio samples [num_samples x num_channels]. Args:
:type samples: ndarray.float32 AudioSegment (AudioSegment): Audio Segment
:param sample_rate: Audio sample rate.
:type sample_rate: int
:param transcript: Transcript text for the speech.
:type transript: str
:raises TypeError: If the sample data type is not float or int.
""" """
def __init__(self, samples, sample_rate, transcript): def __init__(self, samples, sample_rate, transcript):
"""Speech segment abstraction, a subclass of AudioSegment,
with an additional transcript.
Args:
samples (ndarray.float32): Audio samples [num_samples x num_channels].
sample_rate (int): Audio sample rate.
transcript (str): Transcript text for the speech.
"""
AudioSegment.__init__(self, samples, sample_rate) AudioSegment.__init__(self, samples, sample_rate)
self._transcript = transcript self._transcript = transcript
def __eq__(self, other): def __eq__(self, other):
"""Return whether two objects are equal. """Return whether two objects are equal.
Returns:
bool: True, when equal to other
""" """
if not AudioSegment.__eq__(self, other): if not AudioSegment.__eq__(self, other):
return False return False

@ -20,6 +20,7 @@ import tarfile
import time import time
from threading import Thread from threading import Thread
from multiprocessing import Process, Manager, Value from multiprocessing import Process, Manager, Value
from paddle.dataset.common import md5file from paddle.dataset.common import md5file
@ -49,51 +50,3 @@ def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
json_data["duration"] >= min_duration): json_data["duration"] >= min_duration):
manifest.append(json_data) manifest.append(json_data)
return manifest return manifest
def getfile_insensitive(path):
"""Get the actual file path when given insensitive filename."""
directory, filename = os.path.split(path)
directory, filename = (directory or '.'), filename.lower()
for f in os.listdir(directory):
newpath = os.path.join(directory, f)
if os.path.isfile(newpath) and f.lower() == filename:
return newpath
def download_multi(url, target_dir, extra_args):
"""Download multiple files from url to target_dir."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
print("Downloading %s ..." % url)
ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
target_dir)
return ret_code
def download(url, md5sum, target_dir):
"""Download file from url to target_dir, and check md5sum."""
if not os.path.exists(target_dir): os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1])
if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url)
os.system("wget -c " + url + " -P " + target_dir)
print("\nMD5 Chesksum %s ..." % filepath)
if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath
def unpack(filepath, target_dir, rm_tar=False):
"""Unpack the file to the target_dir."""
print("Unpacking %s ..." % filepath)
tar = tarfile.open(filepath)
tar.extractall(target_dir)
tar.close()
if rm_tar == True:
os.remove(filepath)
class XmapEndSignal():
pass

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -22,11 +22,10 @@ from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
from utils import checkpoint from deepspeech.utils import checkpoint
from deepspeech.decoders.swig_wrapper import Scorer
from decoders.swig_wrapper import Scorer from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
from decoders.swig_wrapper import ctc_greedy_decoder from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
from decoders.swig_wrapper import ctc_beam_search_decoder_batch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -661,16 +660,19 @@ class DeepSpeech2(nn.Layer):
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path, self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
vocab_list) vocab_list)
def decode_probs(self, probs, vocab_list, decoding_method, lang_model_path, def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
beam_alpha, beam_beta, beam_size, cutoff_prob, lang_model_path, beam_alpha, beam_beta, beam_size,
cutoff_top_n, num_processes): cutoff_prob, cutoff_top_n, num_processes):
""" probs: activation after softmax """ """ probs: activation after softmax
logits_len: audio output lens
"""
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
if decoding_method == "ctc_greedy": if decoding_method == "ctc_greedy":
result_transcripts = self._decode_batch_greedy( result_transcripts = self._decode_batch_greedy(
probs_split=probs, vocab_list=vocab_list) probs_split=probs_split, vocab_list=vocab_list)
elif decoding_method == "ctc_beam_search": elif decoding_method == "ctc_beam_search":
result_transcripts = self._decode_batch_beam_search( result_transcripts = self._decode_batch_beam_search(
probs_split=probs, probs_split=probs_split,
beam_alpha=beam_alpha, beam_alpha=beam_alpha,
beam_beta=beam_beta, beam_beta=beam_beta,
beam_size=beam_size, beam_size=beam_size,
@ -686,12 +688,11 @@ class DeepSpeech2(nn.Layer):
def decode(self, audio, audio_len, vocab_list, decoding_method, def decode(self, audio, audio_len, vocab_list, decoding_method,
lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob, lang_model_path, beam_alpha, beam_beta, beam_size, cutoff_prob,
cutoff_top_n, num_processes): cutoff_top_n, num_processes):
_, probs, audio_lens = self.predict(audio, audio_len) _, probs, logits_lens = self.predict(audio, audio_len)
probs_split = [probs[i, :l, :] for i, l in enumerate(audio_lens)] return self.decode_probs(probs.numpy(), logits_lens, vocab_list,
return self.decode_probs(probs_split, vocab_list, decoding_method, decoding_method, lang_model_path, beam_alpha,
lang_model_path, beam_alpha, beam_beta, beam_beta, beam_size, cutoff_prob,
beam_size, cutoff_prob, cutoff_top_n, cutoff_top_n, num_processes)
num_processes)
def from_pretrained(self, checkpoint_path): def from_pretrained(self, checkpoint_path):
"""Build a model from a pretrained model. """Build a model from a pretrained model.

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from deepspeech.training.trainer import *

@ -59,7 +59,8 @@ def default_argument_parser():
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
# overwrite extra config and default config # overwrite extra config and default config
parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") #parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--opts", type=str, default=[], nargs='+', help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
# yapd: enable # yapd: enable
return parser return parser

@ -24,8 +24,8 @@ from paddle import distributed as dist
from paddle.distributed.utils import get_gpus from paddle.distributed.utils import get_gpus
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from utils import checkpoint from deepspeech.utils import checkpoint
from utils import mp_tools from deepspeech.utils import mp_tools
__all__ = ["Trainer"] __all__ = ["Trainer"]
@ -148,20 +148,6 @@ class Trainer():
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration self.iteration = iteration
def read_batch(self):
"""Read a batch from the train_loader.
Returns
-------
List[Tensor]
A batch.
"""
try:
batch = next(self.iterator)
except StopIteration:
self.new_epoch()
batch = next(self.iterator)
return batch
def new_epoch(self): def new_epoch(self):
"""Reset the train loader and increment ``epoch``. """Reset the train loader and increment ``epoch``.
""" """
@ -169,7 +155,6 @@ class Trainer():
# batch sampler epoch start from 0 # batch sampler epoch start from 0
self.train_loader.batch_sampler.set_epoch(self.epoch) self.train_loader.batch_sampler.set_epoch(self.epoch)
self.epoch += 1 self.epoch += 1
self.iterator = iter(self.train_loader)
def train(self): def train(self):
"""The training process. """The training process.
@ -177,16 +162,22 @@ class Trainer():
It includes forward/backward/update and periodical validation and It includes forward/backward/update and periodical validation and
saving. saving.
""" """
self.logger.info(
f"Train Total Examples: {len(self.train_loader.dataset)}")
self.new_epoch() self.new_epoch()
while self.iteration < self.config.training.max_iteration: while self.epoch <= self.config.training.n_epoch:
try:
for batch in self.train_loader:
self.iteration += 1 self.iteration += 1
self.train_batch() self.train_batch(batch)
except Exception as e:
self.logger.error(e)
pass
if self.iteration % self.config.training.valid_interval == 0:
self.valid() self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save() self.save()
self.lr_scheduler.step()
self.new_epoch()
def run(self): def run(self):
"""The routine of the experiment after setup. This method is intended """The routine of the experiment after setup. This method is intended

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -16,15 +16,15 @@ import os
import time import time
import logging import logging
import numpy as np import numpy as np
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle.nn import Layer from paddle.nn import Layer
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from utils import mp_tools from deepspeech.utils import mp_tools
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel("INFO")
__all__ = ["load_parameters", "save_parameters"] __all__ = ["load_parameters", "save_parameters"]

@ -0,0 +1,57 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains common utility functions."""
import distutils.util
def print_arguments(args):
"""Print argparse's arguments.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
parser.add_argument("name", default="Jonh", type=str, help="User name.")
args = parser.parse_args()
print_arguments(args)
:param args: Input argparse.Namespace for printing.
:type args: argparse.Namespace
"""
print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).items()):
print("%s: %s" % (arg, value))
print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
"""Add argparse's argument.
Usage:
.. code-block:: python
parser = argparse.ArgumentParser()
add_argument("name", str, "Jonh", "User name.", parser)
args = parser.parse_args()
"""
type = distutils.util.strtobool if type == bool else type
argparser.add_argument(
"--" + argname,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)

@ -23,11 +23,12 @@ import struct
import wave import wave
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
import _init_paths
from data_utils.data import DataGenerator from deepspeech.frontend.utility import read_manifest
from model_utils.model import DeepSpeech2Model from deepspeech.utils.utility import add_arguments, print_arguments
from data_utils.utility import read_manifest
from utils.utility import add_arguments, print_arguments from deepspeech.exps.deepspeech2.model import DeepSpeech2Model
from deepspeech.exps.deepspeech2.dataset import DataGenerator
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)

@ -0,0 +1,2 @@
data
ckpt*

@ -34,18 +34,14 @@ training:
lr_decay: 0.83 lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 5.0 global_grad_clip: 5.0
max_iteration: 500000
plot_interval: 1000
save_interval: 1000
valid_interval: 1000
decoding: decoding:
batch_size: 10 batch_size: 128
error_rate_type: cer error_rate_type: cer
decoding_method: ctc_beam_search decoding_method: ctc_beam_search
lang_model_path: models/lm/zh_giga.no_cna_cmn.prune01244.klm lang_model_path: data/lm/zh_giga.no_cna_cmn.prune01244.klm
alpha: 2.6 alpha: 2.6
beta: 5.0 beta: 5.0
beam_size: 300 beam_size: 300
cutoff_prob: 1.0 cutoff_prob: 0.99
cutoff_top_n: 40 cutoff_top_n: 40
num_proc_bsearch: 10 num_proc_bsearch: 10

@ -2,10 +2,13 @@
mkdir -p data mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
# download data, generate manifests # download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python3 local/aishell.py \ PYTHONPATH=.:$PYTHONPATH python3 ${TARGET_DIR}/aishell/aishell.py \
--manifest_prefix="data/manifest" \ --manifest_prefix="data/manifest" \
--target_dir="${MAIN_ROOT}/dataset/aishell" --target_dir="${TARGET_DIR}/aishell"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Prepare Aishell failed. Terminated." echo "Prepare Aishell failed. Terminated."
@ -14,7 +17,7 @@ fi
# build vocabulary # build vocabulary
python3 ${MAIN_ROOT}/tools/build_vocab.py \ python3 ${MAIN_ROOT}/utils/build_vocab.py \
--count_threshold=0 \ --count_threshold=0 \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
--manifest_paths "data/manifest.train" "data/manifest.dev" --manifest_paths "data/manifest.train" "data/manifest.dev"
@ -26,7 +29,7 @@ fi
# compute mean and stddev for normalizer # compute mean and stddev for normalizer
python3 ${MAIN_ROOT}/tools/compute_mean_std.py \ python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train" \ --manifest_path="data/manifest.train" \
--num_samples=2000 \ --num_samples=2000 \
--specgram_type="linear" \ --specgram_type="linear" \

@ -1,10 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
. ../../utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm' URL='https://deepspeech.bj.bcebos.com/zh_lm/zh_giga.no_cna_cmn.prune01244.klm'
MD5="29e02312deb2e59b3c8686c7966d4fe3" MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=./zh_giga.no_cna_cmn.prune01244.klm TARGET=${DIR}/zh_giga.no_cna_cmn.prune01244.klm
echo "Download language model ..." echo "Download language model ..."

@ -1,10 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
. ../../utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
DIR=data/pretrain
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_fluid.tar.gz' URL='https://deepspeech.bj.bcebos.com/mandarin_models/aishell_model_fluid.tar.gz'
MD5=2bf0cc8b6d5da2a2a787b5cc36a496b5 MD5=2bf0cc8b6d5da2a2a787b5cc36a496b5
TARGET=./aishell_model_fluid.tar.gz TARGET=${DIR}/aishell_model_fluid.tar.gz
echo "Download Aishell model ..." echo "Download Aishell model ..."
@ -13,7 +16,7 @@ if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!" echo "Fail to download Aishell model!"
exit 1 exit 1
fi fi
tar -zxvf $TARGET tar -zxvf $TARGET -C ${DIR}
exit 0 exit 0

@ -2,14 +2,12 @@
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null bash local/download_lm_ch.sh
bash download_lm_ch.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
python3 -u ${MAIN_ROOT}/infer.py \ python3 -u ${BIN_DIR}/infer.py \
--device 'gpu' \ --device 'gpu' \
--nproc 1 \ --nproc 1 \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \

@ -1,22 +1,16 @@
#! /usr/bin/env bash #! /usr/bin/env bash
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null bash local/download_lm_ch.sh
bash download_lm_ch.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# download well-trained model # download well-trained model
cd ${MAIN_ROOT}/models/aishell > /dev/null bash local/download_model.sh
bash download_model.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# infer # infer
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
@ -35,10 +29,10 @@ python3 -u ${MAIN_ROOT}/infer.py \
--use_gpu=False \ --use_gpu=False \
--share_rnn_weights=False \ --share_rnn_weights=False \
--infer_manifest="data/manifest.test" \ --infer_manifest="data/manifest.test" \
--mean_std_path="${MAIN_ROOT}/models/aishell/mean_std.npz" \ --mean_std_path="data/pretrain/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/aishell/vocab.txt" \ --vocab_path="data/pretrain/vocab.txt" \
--model_path="${MAIN_ROOT}/models/aishell" \ --model_path="data/pretrain" \
--lang_model_path="${MAIN_ROOT}/models/lm/zh_giga.no_cna_cmn.prune01244.klm" \ --lang_model_path="data/lm/zh_giga.no_cna_cmn.prune01244.klm" \
--decoding_method="ctc_beam_search" \ --decoding_method="ctc_beam_search" \
--error_rate_type="cer" \ --error_rate_type="cer" \
--specgram_type="linear" --specgram_type="linear"

@ -1,19 +1,16 @@
#! /usr/bin/env bash #! /usr/bin/env bash
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null bash local/download_lm_ch.sh
bash download_lm_ch.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
python3 -u ${BIN_DIR}/test.py \
python3 -u ${MAIN_ROOT}/test.py \
--device 'gpu' \ --device 'gpu' \
--nproc 1 \ --nproc 1 \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \
--output ckpt --checkpoint_path ${1}
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in evaluation!" echo "Failed in evaluation!"

@ -1,47 +1,26 @@
#! /usr/bin/env bash #! /usr/bin/env bash
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null bash local/download_lm_ch.sh
bash download_lm_ch.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# download well-trained model # download well-trained model
cd ${MAIN_ROOT}/models/aishell > /dev/null bash local/download_model.sh
bash download_model.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# evaluate model # evaluate model
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ CUDA_VISIBLE_DEVICES=0 \
python3 -u ${MAIN_ROOT}/test.py \ python3 -u ${BIN_DIR}/test.py \
--batch_size=128 \ --device 'gpu' \
--beam_size=300 \ --nproc 1 \
--num_proc_bsearch=8 \ --config conf/deepspeech2.yaml \
--num_conv_layers=2 \ --checkpoint_path data/pretrain/params.pdparams \
--num_rnn_layers=3 \ --opts data.mean_std_filepath data/pretrain/mean_std.npz \
--rnn_layer_size=1024 \ --opts data.vocab_filepath data/pretrain/vocab.txt
--alpha=2.6 \
--beta=5.0 \
--cutoff_prob=0.99 \
--cutoff_top_n=40 \
--use_gru=True \
--use_gpu=True \
--share_rnn_weights=False \
--test_manifest="data/manifest.test" \
--mean_std_path="${MAIN_ROOT}/models/aishell/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/aishell/vocab.txt" \
--model_path="${MAIN_ROOT}/models/aishell" \
--lang_model_path="${MAIN_ROOT}/models/lm/zh_giga.no_cna_cmn.prune01244.klm" \
--decoding_method="ctc_beam_search" \
--error_rate_type="cer" \
--specgram_type="linear"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in evaluation!" echo "Failed in evaluation!"

@ -4,11 +4,14 @@
# if you wish to resume from an exists model, uncomment --init_from_pretrained_model # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
export FLAGS_sync_nccl_allreduce=0 export FLAGS_sync_nccl_allreduce=0
python3 -u ${MAIN_ROOT}/train.py \ ngpu=$(echo ${CUDA_VISIBLE_DEVICES} | python -c 'import sys; a = sys.stdin.read(); print(len(a.split(",")));')
echo "using $ngpu gpus..."
python3 -u ${BIN_DIR}/train.py \
--device 'gpu' \ --device 'gpu' \
--nproc 4 \ --nproc ${ngpu} \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \
--output ckpt-${1} --output ckpt
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then

@ -1,7 +1,7 @@
#! /usr/bin/env bash #! /usr/bin/env bash
# grid-search for hyper-parameters in language model # grid-search for hyper-parameters in language model
python3 -u ${MAIN_ROOT}/tune.py \ python3 -u ${BIN_DIR}/tune.py \
--device 'gpu' \ --device 'gpu' \
--nproc 1 \ --nproc 1 \
--config conf/deepspeech2.yaml \ --config conf/deepspeech2.yaml \

@ -1 +0,0 @@
../../models

@ -8,3 +8,6 @@ export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin

@ -1,21 +1,16 @@
#!/bin/bash #!/bin/bash
source path.sh source path.sh
# only demos
# prepare data # prepare data
bash ./local/data.sh bash ./local/data.sh
# test pretrain model
bash ./local/test_golden.sh
# test pretain model
bash ./local/infer_golden.sh
# train model # train model
bash ./local/train.sh CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./local/train.sh
# test model # test model
bash ./local/test.sh CUDA_VISIBLE_DEVICES=0 bash ./local/test.sh ckpt/checkpoints/step-3284
# infer model # infer model
bash ./local/infer.sh CUDA_VISIBLE_DEVICES=0 bash ./local/infer.sh ckpt/checkpoints/step-3284

@ -1,11 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
. ../../utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5" MD5="099a601759d467cd0a8523ff939819c5"
TARGET=./common_crawl_00.prune01111.trie.klm TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..." echo "Download language model ..."
download $URL $MD5 $TARGET download $URL $MD5 $TARGET

@ -1,10 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
. ../../utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
DIR=data/pretrain
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/demo_models/baidu_en8k_model_fluid.tar.gz' URL='https://deepspeech.bj.bcebos.com/demo_models/baidu_en8k_model_fluid.tar.gz'
MD5=7e58fbf64aa4ecf639b049792ddcf788 MD5=7e58fbf64aa4ecf639b049792ddcf788
TARGET=./baidu_en8k_model_fluid.tar.gz TARGET=${DIR}/baidu_en8k_model_fluid.tar.gz
echo "Download BaiduEn8k model ..." echo "Download BaiduEn8k model ..."

@ -6,3 +6,8 @@ export LC_ALL=C
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C # Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8 export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=deepspeech2
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin

@ -3,22 +3,17 @@
source path.sh source path.sh
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null
bash download_lm_en.sh bash download_lm_en.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# download well-trained model # download well-trained model
cd ${MAIN_ROOT}/models/baidu_en8k > /dev/null
bash download_model.sh bash download_model.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# infer # infer
CUDA_VISIBLE_DEVICES=0 \ CUDA_VISIBLE_DEVICES=0 \
@ -37,10 +32,10 @@ python3 -u ${MAIN_ROOT}/infer.py \
--use_gpu=False \ --use_gpu=False \
--share_rnn_weights=False \ --share_rnn_weights=False \
--infer_manifest="${MAIN_ROOT}/examples/librispeech/data/manifest.test-clean" \ --infer_manifest="${MAIN_ROOT}/examples/librispeech/data/manifest.test-clean" \
--mean_std_path="${MAIN_ROOT}/models/baidu_en8k/mean_std.npz" \ --mean_std_path="data/pretrain/baidu_en8k/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/baidu_en8k/vocab.txt" \ --vocab_path="data/pretrain/baidu_en8k/vocab.txt" \
--model_path="${MAIN_ROOT}/models/baidu_en8k" \ --model_path="data/pretrain/baidu_en8k" \
--lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \ --lang_model_path="data/lm/common_crawl_00.prune01111.trie.klm" \
--decoding_method="ctc_beam_search" \ --decoding_method="ctc_beam_search" \
--error_rate_type="wer" \ --error_rate_type="wer" \
--specgram_type="linear" --specgram_type="linear"

@ -3,21 +3,17 @@
source path.sh source path.sh
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null
bash download_lm_en.sh bash download_lm_en.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# download well-trained model # download well-trained model
cd ${MAIN_ROOT}/models/baidu_en8k > /dev/null
bash download_model.sh bash download_model.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
# evaluate model # evaluate model
@ -37,11 +33,11 @@ python3 -u ${MAIN_ROOT}/test.py \
--use_gru=True \ --use_gru=True \
--use_gpu=False \ --use_gpu=False \
--share_rnn_weights=False \ --share_rnn_weights=False \
--test_manifest="data/manifest.test-clean" \ --test_manifest="${MAIN_ROOT}/examples/librispeech/data/manifest.test-clean" \
--mean_std_path="${MAIN_ROOT}/models/baidu_en8k/mean_std.npz" \ --mean_std_path="data/pretrain/baidu_en8k/mean_std.npz" \
--vocab_path="${MAIN_ROOT}/models/baidu_en8k/vocab.txt" \ --vocab_path="data/pretrain/baidu_en8k/vocab.txt" \
--model_path="${MAIN_ROOT}/models/baidu_en8k" \ --model_path="data/pretrain/baidu_en8k" \
--lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \ --lang_model_path="data/lm/common_crawl_00.prune01111.trie.klm" \
--decoding_method="ctc_beam_search" \ --decoding_method="ctc_beam_search" \
--error_rate_type="wer" \ --error_rate_type="wer" \
--specgram_type="linear" --specgram_type="linear"

@ -0,0 +1 @@
data_aishell*

@ -24,7 +24,7 @@ import codecs
import soundfile import soundfile
import json import json
import argparse import argparse
from data_utils.utility import download, unpack from utils.utility import download, unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')

@ -29,7 +29,8 @@ import json
import io import io
from paddle.v2.dataset.common import md5file from paddle.v2.dataset.common import md5file
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') #DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
DATA_HOME = os.path.expanduser('.')
URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ" URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
MD5 = "c3ff512618d7a67d4f85566ea1bc39ec" MD5 = "c3ff512618d7a67d4f85566ea1bc39ec"

@ -0,0 +1,7 @@
dev-clean/
dev-other/
test-clean/
test-other/
train-clean-100/
train-clean-360/
train-other-500/

@ -27,10 +27,10 @@ import soundfile
import json import json
import codecs import codecs
import io import io
from data_utils.utility import download, unpack from utils.utility import download, unpack
URL_ROOT = "http://www.openslr.org/resources/12" URL_ROOT = "http://www.openslr.org/resources/12"
URL_ROOT = "https://openslr.magicdatatech.com/resources/12" #URL_ROOT = "https://openslr.magicdatatech.com/resources/12"
URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz" URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz"
URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz" URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz" URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz"

@ -0,0 +1,4 @@
dev-clean/
manifest.dev-clean
manifest.train-clean
train-clean/

@ -0,0 +1,115 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Librispeech ASR datasets.
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
import distutils.util
import os
import sys
import argparse
import soundfile
import json
import codecs
import io
from utils.utility import download, unpack
URL_ROOT = "http://www.openslr.org/resources/31"
URL_TRAIN_CLEAN = URL_ROOT + "/train-clean-5.tar.gz"
URL_DEV_CLEAN = URL_ROOT + "/dev-clean-2.tar.gz"
MD5_TRAIN_CLEAN = "5df7d4e78065366204ca6845bb08f490"
MD5_DEV_CLEAN = "6d7ab67ac6a1d2c993d050e16d61080d"
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default='~/.cache/paddle/dataset/speech/libri',
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path):
"""Create a manifest json file summarizing the data set, with each line
containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
"""
print("Creating manifest %s ..." % manifest_path)
json_lines = []
for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt')
]
if len(text_filelist) > 0:
text_filepath = os.path.join(subfolder, text_filelist[0])
for line in io.open(text_filepath, encoding="utf8"):
segments = line.strip().split()
text = ' '.join(segments[1:]).lower()
audio_filepath = os.path.join(subfolder, segments[0] + '.flac')
audio_data, samplerate = soundfile.read(audio_filepath)
duration = float(len(audio_data)) / samplerate
json_lines.append(
json.dumps({
'audio_filepath': audio_filepath,
'duration': duration,
'text': text
}))
with codecs.open(manifest_path, 'w', 'utf-8') as out_file:
for line in json_lines:
out_file.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create summmary manifest file.
"""
if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
# download
filepath = download(url, md5sum, target_dir)
# unpack
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(target_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=URL_TRAIN_CLEAN,
md5sum=MD5_TRAIN_CLEAN,
target_dir=os.path.join(args.target_dir, "train-clean"),
manifest_path=args.manifest_prefix + ".train-clean")
prepare_dataset(
url=URL_DEV_CLEAN,
md5sum=MD5_DEV_CLEAN,
target_dir=os.path.join(args.target_dir, "dev-clean"),
manifest_path=args.manifest_prefix + ".dev-clean")
if __name__ == '__main__':
main()

@ -0,0 +1,123 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Aishell mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import codecs
import soundfile
import json
import argparse
from utils.utility import download, unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'https://www.openslr.org/resources/17'
DATA_URL = URL_ROOT + '/musan.tar.gz'
MD5_DATA = ''
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/musan",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aishell_transcript_v0.8.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '': continue
audio_id, text = line.split(' ', 1)
# remove withespace
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for type in data_types:
del json_lines[:]
audio_dir = os.path.join(data_dir, 'wav', type)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'audio_filepath': audio_path,
'duration': duration,
'text': text
},
ensure_ascii=False))
manifest_path = manifest_path_prefix + '.' + type
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
if __name__ == '__main__':
main()

@ -0,0 +1,123 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepare Aishell mandarin dataset
Download, unpack and create manifest files.
Manifest file is a json-format file with each line containing the
meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import codecs
import soundfile
import json
import argparse
from data_utils.utility import download, unpack
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
URL_ROOT = 'http://www.openslr.org/resources/28'
DATA_URL = URL_ROOT + '/rirs_noises.zip'
MD5_DATA = ''
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--target_dir",
default=DATA_HOME + "/Aishell",
type=str,
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
args = parser.parse_args()
def create_manifest(data_dir, manifest_path_prefix):
print("Creating manifest %s ..." % manifest_path_prefix)
json_lines = []
transcript_path = os.path.join(data_dir, 'transcript',
'aishell_transcript_v0.8.txt')
transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip()
if line == '': continue
audio_id, text = line.split(' ', 1)
# remove withespace
text = ''.join(text.split())
transcript_dict[audio_id] = text
data_types = ['train', 'dev', 'test']
for type in data_types:
del json_lines[:]
audio_dir = os.path.join(data_dir, 'wav', type)
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for fname in filelist:
audio_path = os.path.join(subfolder, fname)
audio_id = fname[:-4]
# if no transcription for audio then skipped
if audio_id not in transcript_dict:
continue
audio_data, samplerate = soundfile.read(audio_path)
duration = float(len(audio_data) / samplerate)
text = transcript_dict[audio_id]
json_lines.append(
json.dumps(
{
'audio_filepath': audio_path,
'duration': duration,
'text': text
},
ensure_ascii=False))
manifest_path = manifest_path_prefix + '.' + type
with codecs.open(manifest_path, 'w', 'utf-8') as fout:
for line in json_lines:
fout.write(line + '\n')
def prepare_dataset(url, md5sum, target_dir, manifest_path):
"""Download, unpack and create manifest file."""
data_dir = os.path.join(target_dir, 'data_aishell')
if not os.path.exists(data_dir):
filepath = download(url, md5sum, target_dir)
unpack(filepath, target_dir)
# unpack all audio tar files
audio_dir = os.path.join(data_dir, 'wav')
for subfolder, _, filelist in sorted(os.walk(audio_dir)):
for ftar in filelist:
unpack(os.path.join(subfolder, ftar), subfolder, True)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
create_manifest(data_dir, manifest_path)
def main():
if args.target_dir.startswith('~'):
args.target_dir = os.path.expanduser(args.target_dir)
prepare_dataset(
url=DATA_URL,
md5sum=MD5_DATA,
target_dir=args.target_dir,
manifest_path=args.manifest_prefix)
if __name__ == '__main__':
main()

@ -1,9 +1,12 @@
#! /usr/bin/env bash #! /usr/bin/env bash
TARGET_DIR=${MAIN_ROOT}/examples/dataset/voxforge
mkdir -p ${TARGET_DIR}
# download data, generate manifests # download data, generate manifests
PYTHONPATH=../../:$PYTHONPATH python voxforge.py \ python ${MAIN_ROOT}/examples/dataset/voxforge/voxforge.py \
--manifest_prefix='./manifest' \ --manifest_prefix="${TARGET_DIR}/manifest" \
--target_dir='./dataset/VoxForge' \ --target_dir="${TARGET_DIR}" \
--is_merge_dialect=True \ --is_merge_dialect=True \
--dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian' --dialects 'american' 'british' 'australian' 'european' 'irish' 'canadian' 'indian'

@ -27,9 +27,9 @@ import json
import argparse import argparse
import shutil import shutil
import subprocess import subprocess
from data_utils.utility import download_multi, unpack, getfile_insensitive from utils.utility import download_multi, unpack, getfile_insensitive
DATA_HOME = './dataset' DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \ DATA_URL = 'http://www.repository.voxforge1.org/downloads/SpeechCorpus/Trunk/' \
'Audio/Main/16kHz_16bit' 'Audio/Main/16kHz_16bit'

@ -0,0 +1,2 @@
data
ckpt*

@ -1,12 +1,12 @@
# https://yaml.org/type/float.html # https://yaml.org/type/float.html
data: data:
train_manifest: data/manifest.tiny train_manifest: data/manifest.train
dev_manifest: data/manifest.tiny dev_manifest: data/manifest.dev-clean
test_manifest: data/manifest.tiny test_manifest: data/manifest.test-clean
mean_std_filepath: data/mean_std.npz mean_std_filepath: data/mean_std.npz
vocab_filepath: data/vocab.txt vocab_filepath: data/vocab.txt
augmentation_config: conf/augmentation.config augmentation_config: conf/augmentation.config
batch_size: 4 batch_size: 20
max_duration: 27.0 max_duration: 27.0
min_duration: 0.0 min_duration: 0.0
specgram_type: linear specgram_type: linear
@ -26,26 +26,22 @@ model:
num_conv_layers: 2 num_conv_layers: 2
num_rnn_layers: 3 num_rnn_layers: 3
rnn_layer_size: 2048 rnn_layer_size: 2048
use_gru: True use_gru: False
share_rnn_weights: True share_rnn_weights: True
training: training:
n_epoch: 20 n_epoch: 50
lr: 1e-5 lr: 5e-4
lr_decay: 0.83
weight_decay: 1e-06 weight_decay: 1e-06
global_grad_clip: 400.0 global_grad_clip: 5.0
max_iteration: 500000
plot_interval: 1000
save_interval: 1000
valid_interval: 1000
decoding: decoding:
batch_size: 128 batch_size: 128
error_rate_type: wer error_rate_type: wer
decoding_method: ctc_beam_search decoding_method: ctc_beam_search
lang_model_path: models/lm/common_crawl_00.prune01111.trie.klm lang_model_path: data/lm/common_crawl_00.prune01111.trie.klm
alpha: 2.5 alpha: 2.5
beta: 0.3 beta: 0.3
beam_size: 500 beam_size: 500
cutoff_prob: 1.0 cutoff_prob: 1.0
cutoff_top_n: 40 cutoff_top_n: 40
num_proc_bsearch: 8 num_proc_bsearch: 8

@ -1,11 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
mkdir -p data mkdir -p data
TARGET_DIR=${MAIN_ROOT}/examples/dataset
mkdir -p ${TARGET_DIR}
# download data, generate manifests # download data, generate manifests
PYTHONPATH=.:$PYTHONPATH python3 local/librispeech.py \ PYTHONPATH=.:$PYTHONPATH python3 ${TARGET_DIR}/librispeech/librispeech.py \
--manifest_prefix="data/manifest" \ --manifest_prefix="data/manifest" \
--target_dir="${MAIN_ROOT}/dataset/librispeech" \ --target_dir="${TARGET_DIR}/librispeech" \
--full_download="True" --full_download="True"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
@ -15,9 +17,8 @@ fi
cat data/manifest.train-* | shuf > data/manifest.train cat data/manifest.train-* | shuf > data/manifest.train
# build vocabulary # build vocabulary
python3 ${MAIN_ROOT}/tools/build_vocab.py \ python3 ${MAIN_ROOT}/utils/build_vocab.py \
--count_threshold=0 \ --count_threshold=0 \
--vocab_path="data/vocab.txt" \ --vocab_path="data/vocab.txt" \
--manifest_paths="data/manifest.train" --manifest_paths="data/manifest.train"
@ -27,9 +28,8 @@ if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
# compute mean and stddev for normalizer # compute mean and stddev for normalizer
python3 ${MAIN_ROOT}/tools/compute_mean_std.py \ python3 ${MAIN_ROOT}/utils/compute_mean_std.py \
--manifest_path="data/manifest.train" \ --manifest_path="data/manifest.train" \
--num_samples=2000 \ --num_samples=2000 \
--specgram_type="linear" \ --specgram_type="linear" \
@ -40,6 +40,5 @@ if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
echo "LibriSpeech Data preparation done." echo "LibriSpeech Data preparation done."
exit 0 exit 0

@ -0,0 +1,20 @@
#! /usr/bin/env bash
. ${MAIN_ROOT}/utils/utility.sh
DIR=data/lm
mkdir -p ${DIR}
URL=https://deepspeech.bj.bcebos.com/en_lm/common_crawl_00.prune01111.trie.klm
MD5="099a601759d467cd0a8523ff939819c5"
TARGET=${DIR}/common_crawl_00.prune01111.trie.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0

@ -1,10 +1,13 @@
#! /usr/bin/env bash #! /usr/bin/env bash
. ../../utils/utility.sh . ${MAIN_ROOT}/utils/utility.sh
DIR=data/pretrain
mkdir -p ${DIR}
URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_model_fluid.tar.gz' URL='https://deepspeech.bj.bcebos.com/eng_models/librispeech_model_fluid.tar.gz'
MD5=fafb11fe57c3ecd107147056453f5348 MD5=fafb11fe57c3ecd107147056453f5348
TARGET=./librispeech_model_fluid.tar.gz TARGET=${DIR}/librispeech_model_fluid.tar.gz
echo "Download LibriSpeech model ..." echo "Download LibriSpeech model ..."
@ -13,7 +16,6 @@ if [ $? -ne 0 ]; then
echo "Fail to download LibriSpeech model!" echo "Fail to download LibriSpeech model!"
exit 1 exit 1
fi fi
tar -zxvf $TARGET tar -zxvf $TARGET -C ${DIR}
exit 0 exit 0

@ -1,43 +1,21 @@
#! /usr/bin/env bash #! /usr/bin/env bash
# download language model # download language model
cd ${MAIN_ROOT}/models/lm > /dev/null bash local/download_lm_en.sh
bash download_lm_en.sh
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
exit 1 exit 1
fi fi
cd - > /dev/null
python3 -u ${BIN_DIR}/infer.py \
--device 'gpu' \
--nproc 1 \
--config conf/deepspeech2.yaml \
--output ckpt
# infer
CUDA_VISIBLE_DEVICES=0 \
python3 -u ${MAIN_ROOT}/infer.py \
--num_samples=10 \
--beam_size=500 \
--num_proc_bsearch=8 \
--num_conv_layers=2 \
--num_rnn_layers=3 \
--rnn_layer_size=2048 \
--alpha=2.5 \
--beta=0.3 \
--cutoff_prob=1.0 \
--cutoff_top_n=40 \
--use_gru=False \
--use_gpu=True \
--share_rnn_weights=True \
--infer_manifest="data/manifest.test-clean" \
--mean_std_path="data/mean_std.npz" \
--vocab_path="data/vocab.txt" \
--model_path="checkpoints/step_final" \
--lang_model_path="${MAIN_ROOT}/models/lm/common_crawl_00.prune01111.trie.klm" \
--decoding_method="ctc_beam_search" \
--error_rate_type="wer" \
--specgram_type="linear"
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
echo "Failed in inference!" echo "Failed in inference!"
exit 1 exit 1
fi fi
exit 0 exit 0

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save