You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/runtime/engine/kaldi/util/kaldi-thread.h

285 lines
11 KiB

// util/kaldi-thread.h
// Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
// Frantisek Skala
// 2017 University of Southern California (Author: Dogan Can)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_THREAD_KALDI_THREAD_H_
#define KALDI_THREAD_KALDI_THREAD_H_ 1
#include <thread>
#include "util/options-itf.h"
#include "util/kaldi-semaphore.h"
// This header provides convenient mechanisms for parallelization.
//
// The class MultiThreader, and the function RunMultiThreaded provide a
// mechanism to run a specified number of jobs in parallel and wait for them
// all to finish. They accept objects of some class C that derives from the
// base class MultiThreadable. C needs to define the operator () that takes
// no arguments. See ExampleClass below.
//
// The class TaskSequencer addresses a different problem typically encountered
// in Kaldi command-line programs that process a sequence of items. The items
// to be processed are coming in. They are all of different sizes, e.g.
// utterances with different numbers of frames. We would like them to be
// processed in parallel to make good use of the threads available but they
// must be output in the same order they came in. Here, we again accept objects
// of some class C with an operator () that takes no arguments. C may also have
// a destructor with side effects (typically some kind of output).
// TaskSequencer is responsible for running the jobs in parallel. It has a
// function Run() that will accept a new object of class C; this will block
// until a thread is free, at which time it will spawn a thread that starts
// running the operator () of the object. When threads are finished running,
// the objects will be deleted. TaskSequencer guarantees that the destructors
// will be called sequentially (not in parallel) and in the same order the
// objects were given to the Run() function, so that it is safe for the
// destructor to have side effects such as outputting data.
// Note: the destructor of TaskSequencer will wait for any remaining jobs that
// are still running and will call the destructors.
namespace kaldi {
// Maximum number of threads, for programs that use threads (which is not many
// of them; e.g. the SGMM update program does). RunMultiThreaded() below uses
// this value as its thread count.
// This is 8 by default -- NOTE(review): the definition lives outside this
// header, so confirm the default against the .cc file.
// You can change this on the command line, where used, with --num-threads.
// Programs that think they will use threads should register it with their
// ParseOptions, as something like:
//   po.Register("num-threads", &g_num_threads, "Number of threads to use.");
extern int32 g_num_threads;
/// Base class for function objects run in parallel by MultiThreader /
/// RunMultiThreaded.
///
/// To create a function object that does part of the job, inherit from this
/// class, implement a copy constructor calling the default copy constructor
/// of this base class (so that thread_id_ and num_threads_ are copied to new
/// instances), and finally implement the operator() that does part of the job
/// based on the thread_id_ and num_threads_ variables.
/// Note: example implementations are in util/kaldi-thread-test.cc
class MultiThreadable {
 public:
  /// Does the main work of the object. Subclasses have to override this.
  virtual void operator() () = 0;

  /// Virtual so that derived objects are destroyed correctly through a base
  /// pointer. Note: the destructor of the object passed by the user will also
  /// be called, so watch out for side effects.
  virtual ~MultiThreadable();

 public:
  // Do not redeclare thread_id_ and num_threads_ in derived classes.
  // They have defaults so that copying a prototype object whose ids were never
  // set (MultiThreader copies the user's object once per thread) does not read
  // indeterminate values. MultiThreader assigns both before calling
  // operator().
  int32 thread_id_ = 0;    // 0 <= thread_id_ < num_threads_
  int32 num_threads_ = 1;

 private:
  // Have additional member variables as needed.
};
class ExampleClass: public MultiThreadable {
public:
ExampleClass(int32 *foo); // Typically there will be an initializer that
// takes arguments.
ExampleClass(const ExampleClass &other); // A copy constructor is also needed;
// some example classes use the default version of this.
void operator() () {
// Does the main function of the class. This
// function will typically want to look at the values of the
// member variables thread_id_ and num_threads_, inherited
// from MultiThreadable.
}
~ExampleClass() {
// Optional destructor. Sometimes useful things happen here,
// for example summing up of certain quantities. See code
// that uses RunMultiThreaded for examples.
}
private:
// Have additional member variables as needed.
};
/// Runs one copy of a MultiThreadable-derived functor per thread and waits for
/// all of them in its destructor.
///
/// The constructor makes max(1, num_threads) copies of `c_in`, gives each copy
/// its thread_id_ (0, 1, ...) and num_threads_, and starts a std::thread that
/// runs that copy's operator(). The destructor joins every started thread.
///
/// Special cases:
///  - num_threads == 0: a single copy is run synchronously inside the
///    constructor, in the calling thread, and no std::thread is created.
///  - num_threads < 0: treated as 1 (one worker thread is created).
///
/// Not copyable: it owns running threads and the objects they operate on.
template<class C>
class MultiThreader {
 public:
  MultiThreader(int32 num_threads, const C &c_in) :
      threads_(std::max<int32>(1, num_threads)),
      cvec_(std::max<int32>(1, num_threads), c_in) {
    if (num_threads == 0) {
      // This is a special case with num_threads == 0, which behaves like with
      // num_threads == 1 but without creating extra threads. This can be
      // useful in GPU computations where threads cannot be used.
      cvec_[0].thread_id_ = 0;
      cvec_[0].num_threads_ = 1;
      (cvec_[0])();
    } else {
      const int32 n = static_cast<int32>(threads_.size());
      try {
        for (int32 i = 0; i < n; i++) {
          cvec_[i].thread_id_ = i;
          cvec_[i].num_threads_ = n;
          // std::ref: the thread runs cvec_[i] in place, not a copy of it.
          threads_[i] = std::thread(std::ref(cvec_[i]));
        }
      } catch (...) {
        // std::thread's constructor may throw (std::system_error). Since
        // construction did not complete, ~MultiThreader will not run. Join the
        // threads already started so that cvec_ is not destroyed while they
        // still use it, and ~std::thread does not call std::terminate().
        JoinAll();
        throw;
      }
    }
  }

  MultiThreader(const MultiThreader &) = delete;
  MultiThreader &operator=(const MultiThreader &) = delete;

  /// Blocks until every worker thread has finished.
  ~MultiThreader() { JoinAll(); }

 private:
  // Joins every thread that is still joinable. Default-constructed threads,
  // e.g. in the num_threads == 0 case, are skipped.
  void JoinAll() {
    for (std::thread &t : threads_)
      if (t.joinable()) t.join();
  }

  std::vector<std::thread> threads_;
  std::vector<C> cvec_;
};
/// Here, class C should inherit from MultiThreadable. Note: if you want to
/// control the number of threads yourself, or need to do something in the main
/// thread of the program while the objects exist, just initialize the
/// MultiThreader<C> object yourself.
template<class C> void RunMultiThreaded(const C &c_in) {
MultiThreader<C> m(g_num_threads, c_in);
}
/// Command-line options for TaskSequencer.
struct TaskSequencerConfig {
  // Maximum number of jobs in their compute phase (operator()) at once.
  // TaskSequencer treats 0 as "run each job synchronously inside Run()".
  int32 num_threads = 1;
  // Maximum number of job threads alive at once, including those waiting to
  // produce their output; <= 0 means num_threads + 20.
  int32 num_threads_total = 0;

  TaskSequencerConfig() { }

  /// Registers --num-threads and --num-threads-total with `opts`.
  void Register(OptionsItf *opts) {
    opts->Register("num-threads", &num_threads,
                   "Number of actively processing "
                   "threads to run in parallel");
    opts->Register("num-threads-total", &num_threads_total,
                   "Total number of "
                   "threads, including those that are waiting on other threads "
                   "to produce their output. Controls memory use. If <= 0, "
                   "defaults to --num-threads plus 20. Otherwise, must "
                   "be >= num-threads.");
  }
};
// TaskSequencer runs jobs of class C in parallel while guaranteeing that the
// jobs' destructors run one at a time, in the same order in which the jobs
// were passed to Run().
//
// C should have an operator () taking no arguments that does some kind of
// computation, and a destructor that produces some kind of output (the
// destructors will be run sequentially in the same order Run was called).
//
// Implementation: each Run() call pushes a RunTaskArgsList node onto the head
// of a singly linked list (thread_list_) and starts a thread on it. After its
// computation, each thread joins the thread of the job submitted just before
// it ("tail") before deleting its own object, which serializes the
// destructors in submission order.
template<class C>
class TaskSequencer {
 public:
  // threads_avail_ is initialized to config.num_threads and limits how many
  // jobs are inside operator() at once. tot_threads_avail_ is initialized to
  // config.num_threads_total (or num_threads + 20 when num_threads_total <= 0)
  // and limits how many job threads are alive at once.
  // NOTE(review): a negative config.num_threads is not rejected here; the
  // resulting behavior depends on Semaphore -- confirm.
  TaskSequencer(const TaskSequencerConfig &config):
      num_threads_(config.num_threads),
      threads_avail_(config.num_threads),
      tot_threads_avail_(config.num_threads_total > 0 ? config.num_threads_total :
                         config.num_threads + 20),
      thread_list_(NULL) {
    KALDI_ASSERT((config.num_threads_total <= 0 ||
                  config.num_threads_total >= config.num_threads) &&
                 "num-threads-total, if specified, must be >= num-threads");
  }

  /// This function takes ownership of the pointer "c", and will delete it
  /// in the same sequence as Run was called on the jobs.
  /// If num_threads_ == 0 the job is run and deleted right here, in the
  /// calling thread. Otherwise this blocks until both semaphores can be
  /// acquired, then starts a new thread running RunTask on the job.
  /// NOTE(review): if the std::thread constructor throws, the new node is
  /// already linked into thread_list_ with a non-joinable thread, and the two
  /// semaphore slots are never released.
  void Run(C *c) {
    // num_threads == 0: run in the calling (main) thread.
    if (num_threads_ == 0) {
      (*c)();
      delete c;
      return;
    }
    threads_avail_.Wait();  // wait till we have a thread for computation free.
    tot_threads_avail_.Wait();  // this ensures we don't have too many threads
    // waiting on I/O, and consume too much memory.
    // Put the new RunTaskArgsList object at the head of the singly linked
    // list thread_list_. Its tail is the previously submitted job, or NULL.
    thread_list_ = new RunTaskArgsList(this, c, thread_list_);
    thread_list_->thread = std::thread(TaskSequencer<C>::RunTask,
                                       thread_list_);
  }

  /// Blocks until every job submitted so far has finished and been deleted.
  /// You call this at the end if it's more convenient than waiting for the
  /// destructor. Joining only the most recently started thread is enough,
  /// because each thread joins its predecessor ("tail") before it exits.
  void Wait() {
    if (thread_list_ != NULL) {
      thread_list_->thread.join();
      // The thread would not have exited without deleting its tail and
      // setting it to NULL.
      KALDI_ASSERT(thread_list_->tail == NULL);
      delete thread_list_;
      thread_list_ = NULL;
    }
  }

  /// The destructor waits for all submitted jobs to finish (see Wait()).
  ~TaskSequencer() {
    Wait();
  }

 private:
  // One node per submitted job. Nodes form a singly linked list from the most
  // recently submitted job (thread_list_) back to older ones via "tail".
  // A node is deleted by the thread of the next job, or by Wait() if it is
  // the head.
  struct RunTaskArgsList {
    TaskSequencer *me;  // Think of this as a "this" pointer.
    C *c;  // The job owned by this node; RunTask deletes it and sets it to NULL.
    std::thread thread;  // The thread running RunTask on this node.
    RunTaskArgsList *tail;  // Node of the previously submitted job, or NULL.
    RunTaskArgsList(TaskSequencer *me, C *c, RunTaskArgsList *tail):
        me(me), c(c), tail(tail) {}
  };

  // This static function gets run in the threads that we create.
  static void RunTask(RunTaskArgsList *args) {
    // (1) run the job.
    (*(args->c))();  // call operator () on args->c, which does the computation.
    args->me->threads_avail_.Signal();  // Signal that the compute-intensive
    // part of the thread is done (we want to run no more than num_threads_
    // of these at once).

    // (2) we want to destroy the object "c" now, by deleting it. But for
    // correct sequencing (this is the whole point of this class: it is
    // intended to ensure the output of the program is in the correct order),
    // we first wait till the previous thread, whose details are in "tail",
    // is finished.
    if (args->tail != NULL) {
      args->tail->thread.join();
    }
    delete args->c;  // delete the object "c". This may cause some output,
    // e.g. to a stream. We don't need to worry about concurrent access to
    // the output stream, because each thread waits for the previous thread
    // to be done before doing this, so there is no risk of concurrent access.
    args->c = NULL;
    if (args->tail != NULL) {
      // We already joined args->tail->thread, so that thread is done, and
      // before it exited it deleted its own tail and set it to NULL (the
      // next lines of code).
      KALDI_ASSERT(args->tail->tail == NULL);
      delete args->tail;
      args->tail = NULL;
    }
    // At this point we are exiting from the thread. Signal the
    // "tot_threads_avail_" semaphore, which limits the total number of
    // threads that are alive: not only those in active computation in
    // c->operator (), but also those waiting on I/O or on other threads.
    args->me->tot_threads_avail_.Signal();
  }

  int32 num_threads_;  // copy of config.num_threads (since Semaphore doesn't store original count)
  Semaphore threads_avail_;  // Initialized to the number of threads we are
  // supposed to run with; the function Run() waits on this.
  Semaphore tot_threads_avail_;  // We use this semaphore to ensure we don't
  // consume too much memory: it bounds the number of live job threads.
  RunTaskArgsList *thread_list_;  // Most recently submitted job, or NULL.
};
} // namespace kaldi
#endif // KALDI_THREAD_KALDI_THREAD_H_