UFJF - Machine Learning Toolkit  0.51.8
OneVsAll.hpp
1 //
2 // Created by mateus558 on 30/03/2020.
3 //
4 
5 #ifndef UFJF_MLTK_ONEVSALL_HPP
6 #define UFJF_MLTK_ONEVSALL_HPP
7 
8 #include "PrimalClassifier.hpp"
9 #include "DualClassifier.hpp"
10 #include "ufjfmltk/core/Sampling.hpp"
11 
12 namespace mltk{
13  namespace classifier {
17  template<typename T>
18  class OneVsAll : public PrimalClassifier<T>, public DualClassifier<T> {
19  private:
20  using LearnerPointer = std::shared_ptr<Learner<T> >;
21 
23  OverSampling<T> *samp_method;
25  std::vector<LearnerPointer> base_learners;
26 
27  public:
28  OneVsAll() = default;
29 
30  template<template<typename> class ClassifierType>
31  OneVsAll(Data<T> &samples, ClassifierType<T> &classifier, OverSampling<T> *samp_method = nullptr,
32  int _verbose = 0) {
33  this->samples = mltk::make_data<double>(samples);
34  this->samp_method = samp_method;
35 
36  // initialize the base m_learners if samples were given
37  if (this->samples && base_learners.size() == 0) {
38  base_learners.resize(this->samples->classes().size());
39  for (size_t i = 0; i < this->samples->classes().size(); ++i) {
40  // copy the parameters of the given classifier
41  base_learners[i] = std::make_shared<ClassifierType<T> >(classifier);
42  }
43  }
44  }
45 
46  bool train() override;
47 
48  double evaluate(const Point<T> &p, bool raw_value = false) override;
49 
50  std::string getFormulationString() override;
51 
52  };
53 
54  template<typename T>
56  auto classes = this->samples->classes();
57  size_t current_class = 0, j, n_classes = classes.size(), size = this->samples->size();
58 
59  // iterate over each base learner
60  for (auto &learner: base_learners) {
61  Data<T> temp_samples;
62 
63  // copy samples and set all classes not being considered to -1
64  temp_samples.copy(*this->samples);
65  temp_samples.setClasses({-1, 1});
66  for (j = 0; j < size; j++) {
67  temp_samples[j]->Y() = (temp_samples[j]->Y() == classes[current_class]) ? 1 : -1;
68  }
69 
70  // if a over sampling algorithm were given, apply it
71  if (samp_method) {
72  temp_samples.computeClassesDistribution();
73  (*samp_method)(temp_samples);
74  }
75 
76  // train the current learner
77  learner->setSamples(temp_samples);
78  learner->train();
79 
80  // consider the next class
81  current_class++;
82  }
83 
84  return true;
85  }
86 
87  template<typename T>
88  double OneVsAll<T>::evaluate(const Point<T> &p, bool raw_value) {
89  auto classes = this->samples->classes();
90  std::vector<double> dist_hyperplanes(base_learners.size());
91 
92  // classify the point as the class with maximum metrics
93  std::transform(base_learners.begin(), base_learners.end(), dist_hyperplanes.begin(),
94  [&p](auto &learner) {
95  return learner->evaluate(p, true);
96  });
97  size_t max_index =
98  std::max_element(dist_hyperplanes.begin(), dist_hyperplanes.end()) - dist_hyperplanes.begin();
99 
100  return classes[max_index];
101  }
102 
103  template<typename T>
105  return this->base_learners[0]->getFormulationString();
106  }
107  }
108 }
109 
110 #endif //UFJF_MLTK_ONEVSALL_HPP
size_t size() const
Returns the size of the dataset.
Definition: Data.hpp:208
const std::vector< int > classes() const
Returns a vector containing the numeric values of the classes.
Definition: Data.hpp:1831
void setClasses(const std::vector< int > &classes)
Set the classes to use in the dataset.
Definition: Data.hpp:1836
void computeClassesDistribution()
Compute the frequency of each class in the dataset.
Definition: Data.hpp:1843
mltk::Data< T > copy() const
Returns a copy of itself.
Definition: Data.hpp:1551
std::shared_ptr< Data< T > > samples
Samples used in the model training.
Definition: Learner.hpp:21
Base class for the implementation of over sampling methods.
Definition: Sampling.hpp:110
Definition: DualClassifier.hpp:16
Wrapper for the implementation of the one vs all multi class classification algorithm.
Definition: OneVsAll.hpp:18
bool train() override
Function that executes the training phase of a Learner.
Definition: OneVsAll.hpp:55
double evaluate(const Point< T > &p, bool raw_value=false) override
Returns the class of a feature point based on the trained Learner.
Definition: OneVsAll.hpp:88
std::string getFormulationString() override
getFormulationString Returns a string that represents the formulation of the learner (Primal or Dual)...
Definition: OneVsAll.hpp:104
Definition: PrimalClassifier.hpp:14
UFJF-MLTK main namespace for core functionalities.
Definition: classifier/Classifier.hpp:11