UFJF - Machine Learning Toolkit  0.51.8
ThreadPool.hpp
1 #pragma once
2 
15 #define THREAD_POOL_VERSION "v2.0.0 (2021-08-14)"
16 
17 #include <atomic> // std::atomic
18 #include <chrono> // std::chrono
19 #include <cstdint> // std::int_fast64_t, std::uint_fast32_t
20 #include <functional> // std::function
21 #include <future> // std::future, std::promise
22 #include <iostream> // std::cout, std::ostream
23 #include <memory> // std::shared_ptr, std::unique_ptr
24 #include <mutex> // std::mutex, std::scoped_lock
25 #include <queue> // std::queue
26 #include <thread> // std::this_thread, std::thread
27 #include <type_traits> // std::common_type_t, std::decay_t, std::enable_if_t, std::is_void_v, std::invoke_result_t
28 #include <utility> // std::move
29 
30 
31 namespace mltk {
32 // ============================================================================================= //
33 // Begin class ThreadPool //
34 
39 {
40  typedef std::uint_fast32_t ui32;
41  typedef std::uint_fast64_t ui64;
42 
43 public:
44  // ============================
45  // Constructors and destructors
46  // ============================
47 
53  ThreadPool(const ui32 &_thread_count = std::thread::hardware_concurrency())
54  : thread_count(_thread_count ? _thread_count : std::thread::hardware_concurrency()), threads(new std::thread[_thread_count ? _thread_count : std::thread::hardware_concurrency()])
55  {
56  create_threads();
57  }
58 
63  {
65  running = false;
66  destroy_threads();
67  }
68 
69  // =======================
70  // Public member functions
71  // =======================
72 
78  ui64 get_tasks_queued() const
79  {
80  const std::scoped_lock lock(queue_mutex);
81  return tasks.size();
82  }
83 
89  ui32 get_tasks_running() const
90  {
91  return tasks_total - (ui32)get_tasks_queued();
92  }
93 
99  ui32 get_tasks_total() const
100  {
101  return tasks_total;
102  }
103 
109  ui32 get_thread_count() const
110  {
111  return thread_count;
112  }
113 
125  template <typename T1, typename T2, typename F>
126  void parallelize_loop(const T1 &first_index, const T2 &index_after_last, const F &loop, ui32 num_blocks = 0)
127  {
128  typedef std::common_type_t<T1, T2> T;
129  T the_first_index = (T)first_index;
130  T last_index = (T)index_after_last;
131  if (the_first_index == last_index)
132  return;
133  if (last_index < the_first_index)
134  {
135  T temp = last_index;
136  last_index = the_first_index;
137  the_first_index = temp;
138  }
139  last_index--;
140  if (num_blocks == 0)
141  num_blocks = thread_count;
142  ui64 total_size = (ui64)(last_index - the_first_index + 1);
143  ui64 block_size = (ui64)(total_size / num_blocks);
144  if (block_size == 0)
145  {
146  block_size = 1;
147  num_blocks = (ui32)total_size > 1 ? (ui32)total_size : 1;
148  }
149  std::atomic<ui32> blocks_running = 0;
150  for (ui32 t = 0; t < num_blocks; t++)
151  {
152  T start = ((T)(t * block_size) + the_first_index);
153  T end = (t == num_blocks - 1) ? last_index + 1 : ((T)((t + 1) * block_size) + the_first_index);
154  blocks_running++;
155  push_task([start, end, &loop, &blocks_running]
156  {
157  loop(start, end);
158  blocks_running--;
159  });
160  }
161  while (blocks_running != 0)
162  {
163  sleep_or_yield();
164  }
165  }
166 
173  template <typename F>
174  void push_task(const F &task)
175  {
176  tasks_total++;
177  {
178  const std::scoped_lock lock(queue_mutex);
179  tasks.push(std::function<void()>(task));
180  }
181  }
182 
/**
 * @brief Push a function with arguments, but no return value, into the task queue.
 *
 * The function and its arguments are copied into a closure, which is then queued through the
 * single-argument overload of push_task().
 *
 * @tparam F The type of the function.
 * @tparam A The types of the arguments.
 * @param task The function to push.
 * @param args The arguments to pass to the function.
 */
template <typename F, typename... A>
void push_task(const F &task, const A &...args)
{
    const auto bound_call = [task, args...]
    {
        task(args...);
    };
    push_task(bound_call);
}
198 
204  void reset(const ui32 &_thread_count = std::thread::hardware_concurrency())
205  {
206  bool was_paused = paused;
207  paused = true;
208  wait_for_tasks();
209  running = false;
210  destroy_threads();
211  thread_count = _thread_count ? _thread_count : std::thread::hardware_concurrency();
212  threads.reset(new std::thread[thread_count]);
213  paused = was_paused;
214  running = true;
215  create_threads();
216  }
217 
/**
 * @brief Submit a function with zero or more arguments and no return value into the task queue,
 * and get an std::future<bool> for it.
 *
 * The future is set to true once the task has run to completion. If the task throws, the
 * exception is stored in the future instead.
 *
 * @tparam F The type of the function.
 * @tparam A The types of the zero or more arguments to pass to the function.
 * @param task The function to submit.
 * @param args The zero or more arguments to pass to the function.
 * @return A future that becomes ready when the task finishes.
 */
template <typename F, typename... A, typename = std::enable_if_t<std::is_void_v<std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>>>>
std::future<bool> submit(const F &task, const A &...args)
{
    // The promise is shared with the queued closure so that it stays alive until the task runs.
    auto task_promise = std::make_shared<std::promise<bool>>();
    std::future<bool> result = task_promise->get_future();
    push_task([task, args..., task_promise]
              {
                  try
                  {
                      task(args...);
                      task_promise->set_value(true);
                  }
                  catch (...)
                  {
                      // Hand the failure over to the future; if even that fails, give up silently.
                      try
                      {
                          task_promise->set_exception(std::current_exception());
                      }
                      catch (...)
                      {
                      }
                  }
              });
    return result;
}
252 
/**
 * @brief Submit a function with zero or more arguments and a return value into the task queue,
 * and get a future for its eventual returned value.
 *
 * If the task throws, the exception is stored in the future instead of a value.
 *
 * @tparam F The type of the function.
 * @tparam A The types of the zero or more arguments to pass to the function.
 * @tparam R The return type of the function.
 * @param task The function to submit.
 * @param args The zero or more arguments to pass to the function.
 * @return A future that receives the value returned by the task.
 */
template <typename F, typename... A, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>, typename = std::enable_if_t<!std::is_void_v<R>>>
std::future<R> submit(const F &task, const A &...args)
{
    // The promise is shared with the queued closure so that it stays alive until the task runs.
    auto task_promise = std::make_shared<std::promise<R>>();
    std::future<R> result = task_promise->get_future();
    push_task([task, args..., task_promise]
              {
                  try
                  {
                      task_promise->set_value(task(args...));
                  }
                  catch (...)
                  {
                      // Hand the failure over to the future; if even that fails, give up silently.
                      try
                      {
                          task_promise->set_exception(std::current_exception());
                      }
                      catch (...)
                      {
                      }
                  }
              });
    return result;
}
287 
292  {
293  while (true)
294  {
295  if (!paused)
296  {
297  if (tasks_total == 0)
298  break;
299  }
300  else
301  {
302  if (get_tasks_running() == 0)
303  break;
304  }
305  sleep_or_yield();
306  }
307  }
308 
309  // ===========
310  // Public data
311  // ===========
312 
316  std::atomic<bool> paused = false;
317 
321  ui32 sleep_duration = 1000;
322 
323 private:
324  // ========================
325  // Private member functions
326  // ========================
327 
331  void create_threads()
332  {
333  for (ui32 i = 0; i < thread_count; i++)
334  {
335  threads[i] = std::thread(&ThreadPool::worker, this);
336  }
337  }
338 
342  void destroy_threads()
343  {
344  for (ui32 i = 0; i < thread_count; i++)
345  {
346  threads[i].join();
347  }
348  }
349 
356  bool pop_task(std::function<void()> &task)
357  {
358  const std::scoped_lock lock(queue_mutex);
359  if (tasks.empty())
360  return false;
361  else
362  {
363  task = std::move(tasks.front());
364  tasks.pop();
365  return true;
366  }
367  }
368 
373  void sleep_or_yield()
374  {
375  if (sleep_duration)
376  std::this_thread::sleep_for(std::chrono::microseconds(sleep_duration));
377  else
378  std::this_thread::yield();
379  }
380 
384  void worker()
385  {
386  while (running)
387  {
388  std::function<void()> task;
389  if (!paused && pop_task(task))
390  {
391  task();
392  tasks_total--;
393  }
394  else
395  {
396  sleep_or_yield();
397  }
398  }
399  }
400 
401  // ============
402  // Private data
403  // ============
404 
408  mutable std::mutex queue_mutex = {};
409 
413  std::atomic<bool> running = true;
414 
418  std::queue<std::function<void()>> tasks = {};
419 
423  ui32 thread_count;
424 
428  std::unique_ptr<std::thread[]> threads;
429 
433  std::atomic<ui32> tasks_total = 0;
434 };
435 
436 // End class ThreadPool //
437 // ============================================================================================= //
438 
439 // ============================================================================================= //
440 // Begin class synced_stream //
441 
446 {
447 public:
453  synced_stream(std::ostream &_out_stream = std::cout)
454  : out_stream(_out_stream){};
455 
462  template <typename... T>
463  void print(const T &...items)
464  {
465  const std::scoped_lock lock(stream_mutex);
466  (out_stream << ... << items);
467  }
468 
/**
 * @brief Print any number of items into the output stream, followed by a newline character.
 *
 * Forwards to print() with '\n' appended, so the items and the newline are written under a
 * single lock.
 *
 * @tparam T The types of the items.
 * @param items The items to print.
 */
template <typename... T>
void println(const T &...items)
{
    const char line_end = '\n';
    print(items..., line_end);
}
480 
481 private:
485  mutable std::mutex stream_mutex = {};
486 
490  std::ostream &out_stream;
491 };
492 
493 // End class synced_stream //
494 // ============================================================================================= //
495 
496 // ============================================================================================= //
497 // Begin class timer //
498 
/**
 * @brief A helper class to measure execution time for benchmarking purposes.
 *
 * Time is read from std::chrono::steady_clock. The start point defaults to the moment the
 * object is constructed, and the elapsed time is zero until stop() is called.
 */
class timer
{
    using i64 = std::int_fast64_t;
    using clock_type = std::chrono::steady_clock;
    using seconds_type = std::chrono::duration<double>;

public:
    /**
     * @brief Start (or restart) measuring time.
     */
    void start()
    {
        start_time = clock_type::now();
    }

    /**
     * @brief Stop measuring time and store the elapsed time since start().
     */
    void stop()
    {
        const clock_type::time_point now = clock_type::now();
        elapsed_time = now - start_time;
    }

    /**
     * @brief Get the number of milliseconds that have elapsed between start() and stop().
     *
     * @return The stored elapsed time converted to whole milliseconds.
     */
    i64 ms() const
    {
        const auto whole_ms = std::chrono::duration_cast<std::chrono::milliseconds>(elapsed_time);
        return whole_ms.count();
    }

private:
    /**
     * @brief The time point when measuring started.
     */
    clock_type::time_point start_time = clock_type::now();

    /**
     * @brief The duration that has elapsed between start() and stop(), in seconds.
     */
    seconds_type elapsed_time = seconds_type::zero();
};
544 
545 // End class timer //
546 // ============================================================================================= //
547 
548 } // namespace mltk
A C++17 thread pool class. The user submits tasks to be executed into a queue. Whenever a thread beco...
Definition: ThreadPool.hpp:39
~ThreadPool()
Destruct the thread pool. Waits for all tasks to complete, then destroys all threads....
Definition: ThreadPool.hpp:62
std::atomic< bool > paused
An atomic variable indicating to the workers to pause. When set to true, the workers temporarily stop...
Definition: ThreadPool.hpp:316
std::future< R > submit(const F &task, const A &...args)
Submit a function with zero or more arguments and a return value into the task queue,...
Definition: ThreadPool.hpp:264
ui32 get_thread_count() const
Get the number of threads in the pool.
Definition: ThreadPool.hpp:109
void wait_for_tasks()
Wait for tasks to be completed. Normally, this function waits for all tasks, both those that are curr...
Definition: ThreadPool.hpp:291
std::future< bool > submit(const F &task, const A &...args)
Submit a function with zero or more arguments and no return value into the task queue,...
Definition: ThreadPool.hpp:228
ui64 get_tasks_queued() const
Get the number of tasks currently waiting in the queue to be executed by the threads.
Definition: ThreadPool.hpp:78
void push_task(const F &task)
Push a function with no arguments or return value into the task queue.
Definition: ThreadPool.hpp:174
ui32 get_tasks_total() const
Get the total number of unfinished tasks - either still in the queue, or running in a thread.
Definition: ThreadPool.hpp:99
ThreadPool(const ui32 &_thread_count=std::thread::hardware_concurrency())
Construct a new thread pool.
Definition: ThreadPool.hpp:53
void parallelize_loop(const T1 &first_index, const T2 &index_after_last, const F &loop, ui32 num_blocks=0)
Parallelize a loop by splitting it into blocks, submitting each block separately to the thread pool,...
Definition: ThreadPool.hpp:126
ui32 get_tasks_running() const
Get the number of tasks currently being executed by the threads.
Definition: ThreadPool.hpp:89
void reset(const ui32 &_thread_count=std::thread::hardware_concurrency())
Reset the number of threads in the pool. Waits for all currently running tasks to be completed,...
Definition: ThreadPool.hpp:204
void push_task(const F &task, const A &...args)
Push a function with arguments, but no return value, into the task queue.
Definition: ThreadPool.hpp:193
A helper class to synchronize printing to an output stream by different threads.
Definition: ThreadPool.hpp:446
synced_stream(std::ostream &_out_stream=std::cout)
Construct a new synced stream.
Definition: ThreadPool.hpp:453
void print(const T &...items)
Print any number of items into the output stream. Ensures that no other threads print to this stream ...
Definition: ThreadPool.hpp:463
void println(const T &...items)
Print any number of items into the output stream, followed by a newline character....
Definition: ThreadPool.hpp:476
A helper class to measure execution time for benchmarking purposes.
Definition: ThreadPool.hpp:503
void start()
Start (or restart) measuring time.
Definition: ThreadPool.hpp:510
void stop()
Stop measuring time and store the elapsed time since start().
Definition: ThreadPool.hpp:518
i64 ms() const
Get the number of milliseconds that have elapsed between start() and stop().
Definition: ThreadPool.hpp:528
UFJF-MLTK main namespace for core functionalities.
Definition: classifier/Classifier.hpp:11