UFJF - Machine Learning Toolkit  0.51.8
DistanceMatrix.hpp
1 #pragma once
2 
#include <stdexcept>
#include <thread>
#include <vector>

#include "ThreadPool.hpp"
#include "DistanceMetric.hpp"
7 
8 namespace mltk::metrics::dist {
9  class BaseMatrix {
10  protected:
11  mltk::Point<double>::Matrix rows;
12  bool isDiagonal{false};
13  size_t threads{ std::thread::hardware_concurrency() };
14 
15  public:
16  BaseMatrix() = default;
17 
18  explicit BaseMatrix(mltk::Data<double> &data, const bool isDiagonal = false, const size_t threads = std::thread::hardware_concurrency()) {
19  this->threads = threads;
20  this->isDiagonal = isDiagonal;
21  this->rows = mltk::Point<double>::Matrix(data.size());
22  }
23 
24  bool isDiagonalMatrix() const {return this->isDiagonal;}
25 
26  size_t size() const {return this->rows.size();}
27 
28  mltk::Point<double> operator[](size_t i) const {return this->rows[i];}
29 
30  mltk::Point<double> & operator[](size_t i) {return this->rows[i];}
31  };
32 
33  template<typename DistanceFunc = metrics::dist::Euclidean<double> >
34  class DistanceMatrix: public BaseMatrix {
35  private:
36  DistanceFunc dist_function{};
37 
38  void compute(const mltk::Data<double> &data) {
39  mltk::ThreadPool pool(threads);
40 
41  auto loop = [data, this](const int a, const int b) {
42  for(size_t idx = a; idx < b; idx++) {
43  this->rows[idx] = mltk::Point<double>((isDiagonal) ? idx+1 : data.size());
44 
45  for(size_t j = 0; j < idx; j++){
46  this->rows[data(idx).Id()-1][data(j).Id()-1] = this->dist_function(data(idx), data(j));
47  }
48 
49  if(isDiagonal) continue;
50 
51  for(size_t j = idx+1; j < data.size(); j++){
52  this->rows[data(idx).Id()-1][data(j).Id()-1] = this->dist_function(data(idx), data(j));
53  }
54  }
55  };
56 
57  pool.parallelize_loop(0, data.size(), loop, threads);
58  pool.wait_for_tasks();
59  }
60 
61  public:
62  DistanceMatrix() = default;
63 
64  explicit DistanceMatrix(mltk::Data<double> &data, const bool isDiagonal = false, const size_t threads = std::thread::hardware_concurrency()) {
65  this->threads = threads;
66  this->isDiagonal = isDiagonal;
67  this->rows = mltk::Point<double>::Matrix(data.size());
68 
69  this->compute(data);
70  }
71  };
72 }
size_t size() const
Returns the size of the dataset.
Definition: Data.hpp:208
std::size_t size() const
Returns the dimension of the point.
Definition: Point.hpp:133
A C++17 thread pool class. The user submits tasks to be executed into a queue. Whenever a thread beco...
Definition: ThreadPool.hpp:39
void wait_for_tasks()
Wait for tasks to be completed. Normally, this function waits for all tasks, both those that are curr...
Definition: ThreadPool.hpp:291
void parallelize_loop(const T1 &first_index, const T2 &index_after_last, const F &loop, ui32 num_blocks=0)
Parallelize a loop by splitting it into blocks, submitting each block separately to the thread pool,...
Definition: ThreadPool.hpp:126
Definition: DistanceMatrix.hpp:9
Definition: DistanceMatrix.hpp:34