UFJF - Machine Learning Toolkit  0.51.8
AdaBoostClassifier.hpp
1 //
2 // Created by mateuscmarim on 20/11/2020.
3 //
4 
5 #pragma once
6 
#include <algorithm>
#include <cmath>
#include <limits>

#include "Ensemble.hpp"
8 
9 namespace mltk {
10  namespace ensemble {
11  template<typename T>
12  class AdaBoostClassifier : public Ensemble<T>, public classifier::Classifier<T> {
13  private:
14  size_t n_estimators{};
15  Point<double> weights;
16  Point<double> _alpha;
17  public:
18  AdaBoostClassifier()=default;
19 
20  template<template<typename > class Estimator>
21  explicit AdaBoostClassifier(const Data<T> &data, Estimator<T> base_estimator, const size_t n_estimators): n_estimators(n_estimators) {
22  this->samples = mltk::make_data<T>(data);
23  this->m_learners.resize(n_estimators);
24 
25  for(size_t i = 0; i < n_estimators; i++){
26  this->m_learners[i] = std::make_shared< Estimator<T> >(base_estimator);
27  }
28  }
29 
30  bool train() override {
31  size_t _size = this->samples->size(), K = this->samples->classes().size();
32  Point<double> err(n_estimators, 0.0);
33  Point<double> alpha(n_estimators, 0.0);
34  // Initialize weights to an uniform distribution
35  this->weights.assign(_size, 1.0/_size);
36 
37  for(size_t m = 0; m < n_estimators; m++){
38  auto learner = this->m_learners[m];
39  learner->setSeed(this->seed+m);
40  learner->setSamples(this->samples);
41  learner->train();
42  // compute the probability of miss classification for each point
43  Point<double> errors(_size, 0.0);
44  for(size_t i = 0; i < _size; i++){
45  auto point = (*this->samples)[i];
46 
47  if(point->Y() != learner->evaluate(*point)) errors[i] = weights[i];
48  }
49  // compute the estimator error as the weighted average of each point error
50  err[m] = mltk::dot(weights, errors)/weights.sum();
51  // compute alpha to be used on weight update
52  alpha[m] = ((err[m] > 0)?std::log((1.0-err[m])/err[m]):1) + std::log(K - 1);
53  weights *= mltk::exp(alpha[m]*errors);
54  // Normalize weights to form a probability distribution
55  weights /= weights.sum();
56  }
57  this->_alpha.X().resize(n_estimators);
58  this->_alpha = alpha;
59  return true;
60  }
61 
62  double evaluate(const Point<T>& p, bool raw_value=false) override {
63  auto classes = this->samples->classes();
64  Point<double> prob(classes.size(), 0.0);
65  for(size_t c = 0; c < classes.size(); c++) {
66  for(size_t m = 0; m < n_estimators; m++) {
67  if(this->m_learners[m]->evaluate(p) == classes[c]) {
68  prob[c] += this->_alpha[m];
69  }
70  }
71  }
72  size_t class_pos = std::max_element(prob.X().begin(), prob.X().end()) - prob.X().begin();
73  return classes[class_pos];
74  }
75 
76  std::string getFormulationString() override {
77  return this->m_learners[0]->getFormulationString();
78  }
79  };
80  }
81 }
std::shared_ptr< Data< T > > samples
Samples used in the model training.
Definition: Learner.hpp:21
size_t seed
seed for random operations.
Definition: Learner.hpp:46
Rep const & X() const
Returns the attributes representation of the point (std::vector by default).
Definition: Point.hpp:139
T sum(const std::function< T(T)> &f=[](T const &t) { return t;}) const
Compute the sum of the components of the point.
Definition: Point.hpp:285
Definition: classifier/Classifier.hpp:17
Definition: AdaBoostClassifier.hpp:12
std::string getFormulationString() override
getFormulationString Returns a string that represents the formulation of the learner (Primal or Dual)...
Definition: AdaBoostClassifier.hpp:76
double evaluate(const Point< T > &p, bool raw_value=false) override
Returns the class of a feature point based on the trained Learner.
Definition: AdaBoostClassifier.hpp:62
bool train() override
Function that executes the training phase of a Learner.
Definition: AdaBoostClassifier.hpp:30
Namespace for ensemble methods.
Definition: ensemble/Ensemble.hpp:16
std::vector< LearnerPointer< T > > m_learners
Pointers to the base learners used by the ensemble method.
Definition: ensemble/Ensemble.hpp:22
UFJF-MLTK main namespace for core functionalities.
Definition: classifier/Classifier.hpp:11
T dot(const Point< T, R > &p, const Point< T, R > &p1)
Computes the dot product of two points.
Definition: Point.hpp:528