diff --git a/deepspeech/exps/deepspeech2/bin/deploy/runtime.py b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
new file mode 100644
index 000000000..22dc9ad57
--- /dev/null
+++ b/deepspeech/exps/deepspeech2/bin/deploy/runtime.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Server-end for the ASR demo."""
+import os
+import time
+import argparse
+import functools
+import paddle
+import numpy as np
+
+from deepspeech.utils.socket_server import warm_up_test
+from deepspeech.utils.socket_server import AsrTCPServer
+from deepspeech.utils.socket_server import AsrRequestHandler
+
+from deepspeech.training.cli import default_argument_parser
+from deepspeech.exps.deepspeech2.config import get_cfg_defaults
+
+from deepspeech.frontend.utility import read_manifest
+from deepspeech.utils.utility import add_arguments, print_arguments
+
+from deepspeech.models.deepspeech2 import DeepSpeech2Model
+from deepspeech.io.dataset import ManifestDataset
+
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+
+def init_predictor(args):
+    # A model dir means a non-combined model; otherwise load the
+    # combined model/params file pair.
+    if args.model_dir is not None:
+        config = Config(args.model_dir)
+    else:
+        config = Config(args.model_file, args.params_file)
+
+    config.enable_memory_optim()
+    if args.use_gpu:
+        config.enable_use_gpu(memory_pool_init_size_mb=1000, device_id=0)
+    else:
+        # If mkldnn is not used, you can set the number of blas threads.
+        # The thread num should not be greater than the number of cores in the CPU.
+        config.set_cpu_math_library_num_threads(4)
+        config.enable_mkldnn()
+
+    predictor = create_predictor(config)
+    return predictor
+
+
+def run(predictor, img):
+    # copy img data to input tensor
+    input_names = predictor.get_input_names()
+    for i, name in enumerate(input_names):
+        input_tensor = predictor.get_input_handle(name)
+        input_tensor.reshape(img[i].shape)
+        input_tensor.copy_from_cpu(img[i].copy())
+
+    # do the inference
+    predictor.run()
+
+    results = []
+    # get out data from output tensor
+    output_names = predictor.get_output_names()
+    for i, name in enumerate(output_names):
+        output_tensor = predictor.get_output_handle(name)
+        output_data = output_tensor.copy_to_cpu()
+        results.append(output_data)
+
+    return results
+
+
+def inference(config, args):
+    predictor = init_predictor(args)
+
+
+def start_server(config, args):
+    """Start the ASR server"""
+    dataset = ManifestDataset(
+        config.data.test_manifest,
+        config.data.vocab_filepath,
+        config.data.mean_std_filepath,
+        augmentation_config="{}",
+        max_duration=config.data.max_duration,
+        min_duration=config.data.min_duration,
+        stride_ms=config.data.stride_ms,
+        window_ms=config.data.window_ms,
+        n_fft=config.data.n_fft,
+        max_freq=config.data.max_freq,
+        target_sample_rate=config.data.target_sample_rate,
+        specgram_type=config.data.specgram_type,
+        use_dB_normalization=config.data.use_dB_normalization,
+        target_dB=config.data.target_dB,
+        random_seed=config.data.random_seed,
+        keep_transcription_text=True)
+
+    model = DeepSpeech2Model.from_pretrained(dataset, config,
+                                             args.checkpoint_path)
+    model.eval()
+
+    # prepare ASR inference handler
+    def file_to_transcript(filename):
+        feature = dataset.process_utterance(filename, "")
+        audio = np.array([feature[0]]).astype('float32')  # [1, D, T]
+        audio_len = feature[0].shape[1]
+        audio_len = np.array([audio_len]).astype('int64')  # [1]
+
+        result_transcript = model.decode(
+            paddle.to_tensor(audio),
+            paddle.to_tensor(audio_len),
+            vocab_list=dataset.vocab_list,
+            decoding_method=config.decoding.decoding_method,
+            lang_model_path=config.decoding.lang_model_path,
+            beam_alpha=config.decoding.alpha,
+            beam_beta=config.decoding.beta,
+            beam_size=config.decoding.beam_size,
+            cutoff_prob=config.decoding.cutoff_prob,
+            cutoff_top_n=config.decoding.cutoff_top_n,
+            num_processes=config.decoding.num_proc_bsearch)
+        return result_transcript[0]
+
+    # warming up with utterances sampled from Librispeech
+    print('-----------------------------------------------------------')
+    print('Warming up ...')
+    warm_up_test(
+        audio_process_handler=file_to_transcript,
+        manifest_path=args.warmup_manifest,
+        num_test_cases=3)
+    print('-----------------------------------------------------------')
+
+    # start the server
+    server = AsrTCPServer(
+        server_address=(args.host_ip, args.host_port),
+        RequestHandlerClass=AsrRequestHandler,
+        speech_save_dir=args.speech_save_dir,
+        audio_process_handler=file_to_transcript)
+    print("ASR Server Started.")
+    server.serve_forever()
+
+
+def main(config, args):
+    start_server(config, args)
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    # yapf: disable
+    add_arg('host_ip', str,
+            'localhost',
+            "Server's IP address.")
+    add_arg('host_port', int, 8086, "Server's IP port.")
+    add_arg('speech_save_dir', str,
+            'demo_cache',
+            "Directory to save demo audios.")
+    add_arg('warmup_manifest', str, None, "Filepath of manifest to warm up.")
+    # paddle.inference model flags (add_arg would prepend another "--").
+    parser.add_argument(
+        "--model_file",
+        type=str,
+        default="",
help="Model filename, Specify this when your model is a combined model." + ) + add_arg( + "--params_file", + type=str, + default="", + help= + "Parameter filename, Specify this when your model is a combined model." + ) + add_arg( + "--model_dir", + type=str, + default=None, + help= + "Model dir, If you load a non-combined model, specify the directory of the model." + ) + add_arg("--use_gpu", + type=bool, + default=False, + help="Whether use gpu.") + args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) + + args.warmup_manifest = config.data.test_manifest + print_arguments(args) + + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) + + main(config, args) diff --git a/deepspeech/exps/deepspeech2/bin/deploy/server.py b/deepspeech/exps/deepspeech2/bin/deploy/server.py index 109beece6..6b99adc3f 100644 --- a/deepspeech/exps/deepspeech2/bin/deploy/server.py +++ b/deepspeech/exps/deepspeech2/bin/deploy/server.py @@ -52,7 +52,6 @@ def start_server(config, args): target_dB=config.data.target_dB, random_seed=config.data.random_seed, keep_transcription_text=True) - model = DeepSpeech2Model.from_pretrained(dataset, config, args.checkpoint_path) model.eval() diff --git a/deepspeech/exps/deepspeech2/bin/tune.py b/deepspeech/exps/deepspeech2/bin/tune.py index afd8d646a..1fc8dc0c1 100644 --- a/deepspeech/exps/deepspeech2/bin/tune.py +++ b/deepspeech/exps/deepspeech2/bin/tune.py @@ -107,6 +107,7 @@ def tune(config, args): target_transcripts = ordid2token(text, text_len) num_ins += audio.shape[0] + # model infer eouts, eouts_len = model.encoder(audio, audio_len) probs = model.decoder.probs(eouts) diff --git a/deepspeech/io/sampler.py b/deepspeech/io/sampler.py index a0cc469ff..5bc49dad8 100644 --- a/deepspeech/io/sampler.py +++ b/deepspeech/io/sampler.py @@ -33,7 +33,7 @@ __all__ = [ ] -def _batch_shuffle(indices, batch_size, clipped=False): +def _batch_shuffle(indices, batch_size, epoch, clipped=False): """Put similarly-sized instances into minibatches for better efficiency and make a batch-wise shuffle. @@ -54,7 +54,7 @@ def _batch_shuffle(indices, batch_size, clipped=False): :return: Batch shuffled mainifest. :rtype: list """ - rng = np.random.RandomState(self.epoch) + rng = np.random.RandomState(epoch) shift_len = rng.randint(0, batch_size - 1) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) rng.shuffle(batch_indices) @@ -120,7 +120,10 @@ class SortagradDistributedBatchSampler(DistributedBatchSampler): # since diff batch examlpe length in batches case instability loss in diff rank, # e.g. rank0 maxlength 20, rank3 maxlength 1000 indices = _batch_shuffle( - indices, self.batch_size * self.nranks, clipped=False) + indices, + self.batch_size * self.nranks, + self.epoch, + clipped=False) elif self._shuffle_method == "instance_shuffle": np.random.RandomState(self.epoch).shuffle(indices) else: @@ -221,7 +224,7 @@ class SortagradBatchSampler(BatchSampler): logger.info(f'dataset shuffle! 
        if self._shuffle_method == "batch_shuffle":
            indices = _batch_shuffle(
-                indices, self.batch_size, clipped=False)
+                indices, self.batch_size, self.epoch, clipped=False)
        elif self._shuffle_method == "instance_shuffle":
            np.random.RandomState(self.epoch).shuffle(indices)
        else:
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index a775aa6b5..ebaed256b 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -278,6 +278,10 @@ class Trainer():
         # handler.setFormatter(formatter)
         # logger.addHandler(handler)
 
+        # stop propagation, since propagating may print the
+        # log multiple times
+        logger.propagate = False
+
         # global logger
         stdout = False
         save_path = log_file
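
Note on the deepspeech/io/sampler.py hunks: `_batch_shuffle` is a module-level function, so the old `rng = np.random.RandomState(self.epoch)` raised a NameError at call time; the patch threads the sampler's epoch through as an explicit argument, which also keeps the shuffle deterministic per epoch across ranks. Below is a minimal runnable sketch of the fixed behavior, assuming only NumPy. The remainder-handling at the end is an assumption (the patch does not show that part of the function body), and the demo calls are illustrative, not part of the patch.

import numpy as np

def _batch_shuffle(indices, batch_size, epoch, clipped=False):
    # Seed with the explicitly passed epoch (the old self.epoch reference
    # was undefined in this free function), so every rank that shares an
    # epoch draws the same shift length and the same batch permutation.
    rng = np.random.RandomState(epoch)
    shift_len = rng.randint(0, batch_size - 1)
    # Group the shifted indices into full batches, then shuffle batch-wise.
    batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batch_indices)
    batch_indices = [idx for batch in batch_indices for idx in batch]
    if not clipped:
        # Reattach the leftover tail and the shifted head so no sample is lost.
        res_len = len(indices) - shift_len - len(batch_indices)
        if res_len != 0:  # indices[-0:] would duplicate the whole list
            batch_indices.extend(indices[-res_len:])
        batch_indices.extend(indices[0:shift_len])
    return batch_indices

a = _batch_shuffle(list(range(100)), batch_size=8, epoch=3)
b = _batch_shuffle(list(range(100)), batch_size=8, epoch=3)
c = _batch_shuffle(list(range(100)), batch_size=8, epoch=4)
print(a == b)  # True: same epoch, same order on every rank
print(a == c)  # almost surely False: a new epoch reseeds the shuffle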