5 #ifndef UFJF_MLTK_KNN_HPP 
    6 #define UFJF_MLTK_KNN_HPP 
    9 #include "ufjfmltk/core/DistanceMatrix.hpp" 
   10 #include "ufjfmltk/core/DistanceMetric.hpp" 
   11 #include "ufjfmltk/core/ThreadPool.hpp" 
   16         namespace classifier {
 
   20             template<
typename T = 
double, 
typename Callable = metrics::dist::Eucl
idean<T> >
 
   26                     Callable dist_function;
 
   28                     bool precomputed = 
false;
 
   29                     std::string algorithm = 
"brute";
 
   35                     explicit KNNClassifier(
size_t _k, std::string _algorithm = 
"brute")
 
   36                             : k(_k), algorithm(_algorithm) {}
 
   39                             : k(_k), algorithm(_algorithm) {
 
   40                         this->
samples = mltk::make_data<T>(_samples);
 
   43                     bool train() 
override;
 
   47                     Callable& metric(){ 
return dist_function; }
 
   50                         this->distances = _distances;
 
   51                         this->precomputed = 
true;
 
   55                         this->precomputed = 
true;
 
   62             template<
typename T, 
typename Callable>
 
   64                 assert(this->samples->dim() == p.
size());
 
   65                 auto points = this->samples->points();
 
   67                 std::vector<int> classes = this->samples->classes();
 
   68                 std::vector<size_t> idx(distances.
size());
 
   69                 std::vector<PointPointer<T>> neigh;
 
   70                 auto p0 = std::make_shared<Point<T> >(p);
 
   72                 if(algorithm == 
"brute"){
 
   74                     std::iota(idx.begin(), idx.end(), 0);
 
   77                         std::transform(points.begin(), points.end(), distances.begin(),
 
   78                                     [&p0, 
this](
const std::shared_ptr<
Point<T> > q) {
 
   79                                         return this->dist_function(*p0, *q);
 
   84                         std::nth_element(idx.begin(), idx.begin() + this->k, idx.end(), [&distances](
size_t i1, 
size_t i2) {
 
   85                             return distances[i1] < distances[i2];
 
   87                     }
else if(!this->distances.isDiagonalMatrix()){
 
   88                         size_t idp = p.
Id()-1;
 
   89                         std::nth_element(idx.begin(), idx.begin() + this->k, idx.end(), [&idp, &distances, &points, 
this](
size_t i1, 
size_t i2) {
 
   90                             size_t id1 = points[i1]->Id()-1;
 
   91                             size_t id2 = points[i2]->Id()-1;
 
   93                             return this->distances[idp][id1] < this->distances[idp][id2];
 
   96                         std::nth_element(idx.begin(), idx.begin() + this->k, idx.end(), [&p, 
this, &points](
size_t i1, 
size_t i2) {
 
   97                             size_t id1 = points[i1]->Id()-1;
 
   98                             size_t id2 = points[i2]->Id()-1;
 
   99                             size_t idp = p.Id()-1;
 
  101                             if(idp == id1 || idp == id2) return false;
 
  103                             size_t idp1 = (idp > id1) ? idp : id1;
 
  104                             size_t id1p = (idp > id1) ? id1 : idp;
 
  106                             size_t idp2 = (idp > id2) ? idp : id2;
 
  107                             size_t id2p = (idp > id2) ? id2 : idp;
 
  109                             return this->distances[idp1][id1p] < this->distances[idp2][id2p];
 
  114                 auto calculateFrequency = [&idx, &points, &neigh, 
this](
int c) {
 
  115                     if (algorithm == 
"brute") {
 
  116                         return std::count_if(idx.begin(), idx.begin() + this->k, [&points, &c](
size_t id) { 
 
  117                             return points[id]->Y() == c; 
 
  120                         return std::count_if(neigh.begin(), neigh.end(), [&c](
auto point) { 
 
  121                             return point->Y() == c; 
 
  127                 std::pair<int, size_t> maxDetails{0, 0}; 
 
  130                 double max_prob = 0.0;
 
  133                 for(
size_t i = 0; i < classes.size(); ++i) {
 
  134                     int freq = calculateFrequency(classes[i]);
 
  136                     if(freq > maxDetails.first) {
 
  137                         double prob = (freq+s)/(k+classes.size()*s);
 
  139                         maxDetails.first = freq;
 
  140                         maxDetails.second = i; 
 
  146                 this->pred_prob = (1-max_prob > 1E-7) ? 1: max_prob;
 
  148                 return classes[maxDetails.second];
 
  151             template<
typename T, 
typename Callable>
 
std::shared_ptr< Data< T > > samples
Samples used in the model training.
Definition: Learner.hpp:21
 
std::size_t size() const
Returns the dimension of the point.
Definition: Point.hpp:133
 
size_t const  & Id() const
Returns the id of the point.
Definition: Point.hpp:180
 
Wrapper for the implementation of the K-Nearest Neighbors classifier algorithm.
Definition: KNNClassifier.hpp:21
 
bool train() override
Function that executes the training phase of a Learner.
Definition: KNNClassifier.hpp:152
 
double evaluate(const Point< T > &p, bool raw_value=false) override
Returns the class of a feature point based on the trained Learner.
Definition: KNNClassifier.hpp:63
 
Definition: PrimalClassifier.hpp:14
 
Definition: DistanceMatrix.hpp:9
 
Definition: DistanceMatrix.hpp:34
 
UFJF-MLTK main namespace for core functionalities.
Definition: classifier/Classifier.hpp:11