5 #ifndef UFJF_MLTK_ONEVSALL_HPP
6 #define UFJF_MLTK_ONEVSALL_HPP
10 #include "ufjfmltk/core/Sampling.hpp"
13 namespace classifier {
20 using LearnerPointer = std::shared_ptr<Learner<T> >;
25 std::vector<LearnerPointer> base_learners;
30 template<
template<
typename>
class ClassifierType>
33 this->samples = mltk::make_data<double>(
samples);
34 this->samp_method = samp_method;
37 if (this->samples && base_learners.
size() == 0) {
38 base_learners.resize(this->samples->
classes().size());
39 for (
size_t i = 0; i < this->samples->
classes().size(); ++i) {
41 base_learners[i] = std::make_shared<ClassifierType<T> >(classifier);
46 bool train()
override;
56 auto classes = this->samples->classes();
57 size_t current_class = 0, j, n_classes = classes.size(), size = this->samples->size();
60 for (
auto &learner: base_learners) {
64 temp_samples.
copy(*this->samples);
66 for (j = 0; j < size; j++) {
67 temp_samples[j]->Y() = (temp_samples[j]->Y() == classes[current_class]) ? 1 : -1;
73 (*samp_method)(temp_samples);
77 learner->setSamples(temp_samples);
89 auto classes = this->samples->classes();
90 std::vector<double> dist_hyperplanes(base_learners.size());
93 std::transform(base_learners.begin(), base_learners.end(), dist_hyperplanes.begin(),
95 return learner->evaluate(p, true);
98 std::max_element(dist_hyperplanes.begin(), dist_hyperplanes.end()) - dist_hyperplanes.begin();
100 return classes[max_index];
105 return this->base_learners[0]->getFormulationString();
size_t size() const
Returns the size of the dataset.
Definition: Data.hpp:208
const std::vector< int > classes() const
Returns a vector containing the numeric values of the classes.
Definition: Data.hpp:1831
void setClasses(const std::vector< int > &classes)
Set the classes to use in the dataset.
Definition: Data.hpp:1836
void computeClassesDistribution()
Compute the frequency of each class in the dataset.
Definition: Data.hpp:1843
mltk::Data< T > copy() const
Returns a copy of itself.
Definition: Data.hpp:1551
std::shared_ptr< Data< T > > samples
Samples used in the model training.
Definition: Learner.hpp:21
Base class for the implementation of oversampling methods.
Definition: Sampling.hpp:110
Definition: DualClassifier.hpp:16
Wrapper for the implementation of the one-vs-all multiclass classification algorithm.
Definition: OneVsAll.hpp:18
bool train() override
Function that executes the training phase of a Learner.
Definition: OneVsAll.hpp:55
double evaluate(const Point< T > &p, bool raw_value=false) override
Returns the class of a feature point based on the trained Learner.
Definition: OneVsAll.hpp:88
std::string getFormulationString() override
getFormulationString Returns a string that represents the formulation of the learner (Primal or Dual)...
Definition: OneVsAll.hpp:104
Definition: PrimalClassifier.hpp:14
UFJF-MLTK main namespace for core functionalities.
Definition: classifier/Classifier.hpp:11