#ifndef UFJF_MLTK_SMO_H
#define UFJF_MLTK_SMO_H

#include <vector>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdlib>
// The library headers providing Data, Kernel, DualClassifier and the int_dll
// helper are assumed to be included in the part of the header not shown here.

namespace classifier {

    template<typename T>
    class SMO : public DualClassifier<T> {
    private:
        // Per-sample bookkeeping for the SMO routines (fields reconstructed from
        // their uses further below).
        struct smo_learning_data {
            int_dll *sv = nullptr;  // node in the support-vector index list, nullptr while alpha == 0
            double error = 0.0;     // cached error E_i = f(x_i) - y_i
            int done = 0;           // "already examined" flag for the current pass
        };

        const double TOL = 0.0001;
        std::vector<smo_learning_data> l_data;
        int_dll *head = nullptr;    // list of indices of the current support vectors
        // Additional members referenced below (C, EPS, MAX_EPOCH) are declared in
        // a part of the header not reproduced here.

        bool examine_example(int i1);
        bool max_errors(int i1, double e1);
        bool iterate_non_bound(int i1);
        bool iterate_all_set(int i1);
        int take_step(int i1, int i2);
        double function(int index);
        bool training_routine();

    public:
        explicit SMO(const Data<T>& samples, KernelType kernel_type = KernelType::INNER_PRODUCT,
                     double param = 0, int verbose = 0);
        ~SMO();

        bool train() override;
        int train_matrix(Kernel<T> *matrix);
        void test_learning();
    };
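
    // How the pieces fit together (Platt's sequential minimal optimization):
    // training_routine() runs the outer loop, examine_example() picks a first
    // multiplier that violates the KKT conditions, max_errors(), iterate_non_bound()
    // and iterate_all_set() choose the second multiplier, and take_step() performs
    // the analytic two-variable update. "head" lists the indices of the current
    // support vectors and l_data caches each sample's error E_i = f(x_i) - y_i.
    //
    // Minimal usage sketch (illustrative only; loading the Data object is omitted
    // and only members shown in this header are used):
    //
    //     mltk::Data<double> data = /* load or build a labeled dataset */;
    //     SMO<double> smo(data, KernelType::INNER_PRODUCT, 0, 1);
    //     bool converged = smo.train();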
    template<typename T>
    SMO<T>::SMO(const Data<T>& samples, KernelType kernel_type, double param, int verbose) {
        this->samples = mltk::make_data<T>(samples);
        this->verbose = verbose;
        this->kernel_type = kernel_type;
        this->kernel_param = param;
    }

    template<typename T>
    SMO<T>::~SMO() {
        if (this->head != nullptr) {
            int_dll::free(&this->head);
        }
        if (this->kernel) delete this->kernel;
    }
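
    // train() computes the kernel matrix, runs the SMO routine and then recovers the
    // primal quantities from the multipliers: for a linear kernel the weights come
    // from getWeight(); otherwise the dual representation is kept, corresponding to
    // w = sum_i alpha_i * y_i * x_i in feature space. The geometric margin is taken
    // as 1 / ||w||, with ||w|| obtained from Kernel::featureSpaceNorm().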
    template<typename T>
    bool SMO<T>::train() {
        size_t i = 0, size = this->samples->size(), dim = this->samples->dim();
        bool ret = true;
        double norm = 0.0;
        std::vector<double> w_saved;

        this->l_data.resize(size);

        // Reset the dual solution before optimizing.
        this->solution.bias = 0;
        for (i = 0; i < size; i++)
            (*this->samples)[i]->Alpha() = 0;
        this->alpha.assign(size, 0.0);

        // (Re)compute the kernel matrix for the current kernel configuration.
        if (this->kernel) delete this->kernel;
        this->kernel = new mltk::Kernel<T>(this->kernel_type, this->kernel_param);
        this->kernel->compute(this->samples);

        ret = training_routine();

        norm = this->kernel->featureSpaceNorm(this->samples);
        if (this->kernel->getType() == 0)
            w_saved = this->getWeight().X();
        else if (this->kernel->getType() == 1 && this->kernel->getParam() == 1)
            w_saved = this->getDualWeightProdInt().X();
        else
            w_saved = this->getDualWeight().X();

        this->solution.w = w_saved;
        this->solution.margin = 1.0 / norm;
        this->solution.alpha.assign(size, 0.0);
        this->solution.svs = 0;

        for (i = 0; i < size; ++i) {
            this->solution.alpha[i] = (*this->samples)[i]->Alpha();
            if ((*this->samples)[i]->Alpha() > 0) ++this->solution.svs;
            if ((*this->samples)[i]->Alpha() > this->C) ret = false;
        }

        if (this->verbose) {
            std::cout << "Number of Support Vectors: " << this->solution.svs << std::endl;
            std::cout << "Margin found: " << this->solution.margin << "\n\n";

            if (this->verbose > 1) {
                std::vector<int> fnames = this->samples->getFeaturesNames();
                for (i = 0; i < dim; i++)
                    std::cout << "W[" << i << "]: " << this->solution.w[i] << std::endl;
                std::cout << "Bias: " << this->solution.bias << "\n\n";
            }
        }

        int_dll::free(&this->head);
        this->l_data.clear();

        return ret;
    }
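
    // examine_example() checks the KKT conditions for sample i1 through
    // r1 = y1 * e1 = y1 * f(x1) - 1:
    //   alpha1 == 0      requires y1 * f(x1) >= 1  (r1 >= 0),
    //   0 < alpha1 < C   requires y1 * f(x1) == 1  (r1 == 0),
    //   alpha1 == C      requires y1 * f(x1) <= 1  (r1 <= 0).
    // If the violation exceeds TOL, a second multiplier is searched for in this
    // order: max_errors() picks the non-bound support vector maximizing |e1 - e2|
    // (largest expected step), iterate_non_bound() tries every non-bound support
    // vector, and iterate_all_set() finally tries every sample.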
    template<typename T>
    bool SMO<T>::examine_example(int i1) {
        size_t i = 0, size = this->samples->size();
        double y1 = 0, alpha1 = 0, e1 = 0, r1 = 0;

        for (i = 0; i < size; ++i) this->l_data[i].done = 0;
        this->l_data[i1].done = 1;

        auto p = (*this->samples)[i1];
        y1 = p->Y();
        alpha1 = p->Alpha();

        if (alpha1 > 0 && alpha1 < this->C) e1 = this->l_data[i1].error;
        else e1 = function(i1) - y1;

        r1 = y1 * e1;

        // Only optimize i1 if it violates the KKT conditions by more than TOL.
        if ((r1 < -this->TOL && alpha1 < this->C) || (r1 > this->TOL && alpha1 > 0)) {
            if (max_errors(i1, e1)) return true;
            else if (iterate_non_bound(i1)) return true;
            else if (iterate_all_set(i1)) return true;
        } else if (this->verbose > 2) std::cout << "Return0 -1\n";

        return false;
    }
    template<typename T>
    bool SMO<T>::max_errors(int i1, double e1) {
        int k = 0, i2 = -1;
        double e2 = 0, temp = 0, max_temp = 0;
        int_dll *list = nullptr;

        if (this->verbose > 2) std::cout << " Max errors iterations\n";

        // Among the non-bound support vectors, pick the one maximizing |e1 - e2|.
        list = this->head->next;
        while (list != nullptr) {
            k = list->index;  // sample index stored in the list node (field name assumed)
            if (this->l_data[k].done == 0 && (*this->samples)[k]->Alpha() < this->C) {
                e2 = this->l_data[k].error;
                temp = fabs(e1 - e2);

                if (temp > max_temp) {
                    max_temp = temp;
                    i2 = k;
                }
            }
            list = list->next;
        }

        return (i2 >= 0 && take_step(i1, i2));
    }
    template<typename T>
    bool SMO<T>::iterate_non_bound(int i1) {
        int k = 0;
        int_dll *list = nullptr;

        if (this->verbose > 2) std::cout << " Non-bound iteration\n";

        // Try every remaining non-bound support vector as the second multiplier.
        list = this->head->next;
        while (list != nullptr) {
            k = list->index;
            if (this->l_data[k].done == 0 && (*this->samples)[k]->Alpha() < this->C)
                if (take_step(i1, k)) return true;
            list = list->next;
        }

        return false;
    }
    template<typename T>
    bool SMO<T>::iterate_all_set(int i1) {
        size_t k = 0, k0 = 0, i2 = 0;
        size_t size = this->samples->size();
        if (this->verbose > 2) std::cout << " All-set iteration\n";
        k0 = rand() % size;  // random starting point (reconstructed)
        for (k = k0; k < size + k0; ++k) {
            i2 = k % size;
            if (this->l_data[i2].done == 0 && take_step(i1, i2))
                return true;
        }
        return false;
    }
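
    // take_step() solves the two-variable subproblem in (alpha1, alpha2) analytically.
    // With K_ab = (*matrix)[a][b], s = y1 * y2 and eta = 2*K_12 - K_11 - K_22
    // (non-positive for a positive semi-definite kernel), the unconstrained optimum
    // along the equality constraint is
    //     new_alpha2 = alpha2 + y2 * (e2 - e1) / eta,
    // clipped to the box [min_val, max_val]:
    //     y1 != y2:  min_val = max(0, alpha2 - alpha1),     max_val = min(C, C + alpha2 - alpha1)
    //     y1 == y2:  min_val = max(0, alpha1 + alpha2 - C), max_val = min(C, alpha1 + alpha2)
    // alpha1 then moves to keep sum_i alpha_i * y_i unchanged:
    //     new_alpha1 = alpha1 - s * (new_alpha2 - alpha2).
    // If eta >= 0 the objective is evaluated at both ends of the interval and the
    // better end is kept. The bias is re-estimated from a multiplier that remains
    // strictly inside (0, C), or from the average of the two estimates otherwise.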
    template<typename T>
    int SMO<T>::take_step(int i1, int i2) {
        int i = 0, y1 = 0, y2 = 0, s = 0;
        double alpha1 = 0, alpha2 = 0, new_alpha1 = 0, new_alpha2 = 0;
        double e1 = 0, e2 = 0, min_val = 0, max_val = 0, eta = 0;
        double max_val_f = 0, min_val_f = 0;
        double bnew = 0, b = 0;
        double t1 = 0, t2 = 0, error_tot = 0;
        int_dll *itr = nullptr;
        dMatrix *matrix = this->kernel->getKernelMatrixPointer();

        if (i1 == i2) return false;
        this->l_data[i2].done = 1;

        b = -this->solution.bias;
        y1 = (*this->samples)[i1]->Y();
        y2 = (*this->samples)[i2]->Y();

        alpha1 = (*this->samples)[i1]->Alpha();
        alpha2 = (*this->samples)[i2]->Alpha();

        // Errors e1 and e2: use the cache for non-bound multipliers.
        if (alpha1 > 0 && alpha1 < this->C) e1 = this->l_data[i1].error;
        else e1 = function(i1) - y1;

        if (alpha2 > 0 && alpha2 < this->C) e2 = this->l_data[i2].error;
        else e2 = function(i2) - y2;

        s = y1 * y2;

        // Box constraint for the new alpha2.
        if (y1 != y2) {
            min_val = std::max(0.0, alpha2 - alpha1);
            max_val = std::min(double(this->C), this->C + alpha2 - alpha1);
        } else {
            min_val = std::max(0.0, alpha2 + alpha1 - this->C);
            max_val = std::min(double(this->C), alpha1 + alpha2);
        }

        if (min_val == max_val) {
            if (this->verbose > 2) std::cout << "return0 2\n";
            return false;
        }

        eta = 2.0 * (*matrix)[i1][i2] - (*matrix)[i1][i1] - (*matrix)[i2][i2];

        if (eta < 0) {
            // Regular case: maximize along the constraint line, then clip.
            new_alpha2 = alpha2 + y2 * (e2 - e1) / eta;

            if (new_alpha2 < min_val) new_alpha2 = min_val;
            else if (new_alpha2 > max_val) new_alpha2 = max_val;
        } else {
            // Degenerate case: evaluate the objective at both ends of the interval.
            double c1 = eta / 2.0;
            double c2 = y2 * (e1 - e2) - eta * alpha2;
            min_val_f = c1 * min_val * min_val + c2 * min_val;
            max_val_f = c1 * max_val * max_val + c2 * max_val;

            if (min_val_f > max_val_f + this->EPS) new_alpha2 = min_val;
            else if (min_val_f < max_val_f - this->EPS) new_alpha2 = max_val;
            else new_alpha2 = alpha2;
        }

        if (fabs(new_alpha2 - alpha2) < this->EPS * (new_alpha2 + alpha2 + this->EPS)) {
            if (this->verbose > 2) std::cout << "return0 3\n";
            return false;
        }

        new_alpha1 = alpha1 - s * (new_alpha2 - alpha2);
        if (new_alpha1 < 0) {
            new_alpha2 += s * new_alpha1;
            new_alpha1 = 0;
        } else if (new_alpha1 > this->C) {
            new_alpha2 += s * (new_alpha1 - this->C);
            new_alpha1 = this->C;
        }

        (*this->samples)[i1]->Alpha() = new_alpha1;
        (*this->samples)[i2]->Alpha() = new_alpha2;

        // Keep the support-vector list in sync with the updated multipliers.
        if (new_alpha1 > 0 && this->l_data[i1].sv == nullptr) {
            int_dll *list = int_dll::append(this->head);
            list->index = i1;
            this->l_data[i1].sv = list;
        } else if (new_alpha1 == 0 && this->l_data[i1].sv != nullptr) {
            int_dll::remove(&(this->l_data[i1].sv));
        }

        if (new_alpha2 > 0 && this->l_data[i2].sv == nullptr) {
            int_dll *list = int_dll::append(this->head);
            list->index = i2;
            this->l_data[i2].sv = list;
        } else if (new_alpha2 == 0 && this->l_data[i2].sv != nullptr) {
            int_dll::remove(&(this->l_data[i2].sv));
        }

        // Update the threshold from a multiplier that stays strictly inside (0, C).
        t1 = y1 * (new_alpha1 - alpha1);
        t2 = y2 * (new_alpha2 - alpha2);

        if (new_alpha1 > 0 && new_alpha1 < this->C)
            bnew = b + e1 + t1 * (*matrix)[i1][i1] + t2 * (*matrix)[i1][i2];
        else if (new_alpha2 > 0 && new_alpha2 < this->C)
            bnew = b + e2 + t1 * (*matrix)[i1][i2] + t2 * (*matrix)[i2][i2];
        else {
            double b1 = 0, b2 = 0;
            b2 = b + e1 + t1 * (*matrix)[i1][i1] + t2 * (*matrix)[i1][i2];
            b1 = b + e2 + t1 * (*matrix)[i1][i2] + t2 * (*matrix)[i2][i2];
            bnew = (b1 + b2) / 2.0;
        }

        b = bnew;
        this->solution.bias = -b;

        // Refresh the cached errors of the other non-bound support vectors.
        itr = this->head->next;
        while (itr != nullptr) {
            i = itr->index;
            if ((i != i1 && i != i2) && (*this->samples)[i]->Alpha() < this->C) {
                this->l_data[i].error = function(i) - (*this->samples)[i]->Y();
                error_tot += this->l_data[i].error;
            }
            itr = itr->next;
        }
        this->l_data[i1].error = 0.0;
        this->l_data[i2].error = 0.0;

        if (this->verbose > 1)
            std::cout << "Total error= " << error_tot << ", alpha(" << i1 << ")= " << new_alpha1
                      << ", alpha(" << i2 << ")= " << new_alpha2 << std::endl;

        return true;
    }
    // Decision function on training sample 'index':
    //     f(x_index) = sum_{i in SVs} alpha_i * y_i * K(x_i, x_index) + bias.
    template<typename T>
    double SMO<T>::function(int index) {
        int i = 0;
        double sum = 0.0;
        dMatrix *matrix = this->kernel->getKernelMatrixPointer();
        int_dll *list = this->head->next;

        while (list != nullptr) {
            i = list->index;
            if ((*this->samples)[i]->Alpha() > 0)
                sum += (*this->samples)[i]->Alpha() * (*this->samples)[i]->Y() * (*matrix)[i][index];
            list = list->next;
        }
        sum += this->solution.bias;

        return sum;
    }
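
    // training_routine() is Platt's outer loop: it alternates full passes over the
    // training set with passes restricted to the non-bound multipliers
    // (0 < alpha < C) and stops when a full pass changes nothing. MAX_EPOCH bounds
    // the total number of passes.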
    template<typename T>
    bool SMO<T>::training_routine() {
        size_t size = this->samples->size();
        size_t k = 0, epoch = 0, tot_changed = 0;
        int num_changed = 0, examine_all = 1;

        // Reset the bias, the multipliers and the per-sample bookkeeping.
        this->solution.bias = 0;
        for (k = 0; k < size; ++k) {
            (*this->samples)[k]->Alpha() = 0;
            this->l_data[k].error = 0;
            this->l_data[k].done = 0;
        }

        while (num_changed > 0 || examine_all) {
            num_changed = 0;
            ++epoch;
            if (epoch > this->MAX_EPOCH) return false;

            if (examine_all) {
                for (k = 0; k < size; ++k)
                    num_changed += examine_example(k);
            } else {
                for (k = 0; k < size; ++k)
                    if ((*this->samples)[k]->Alpha() > 0 && (*this->samples)[k]->Alpha() < this->C)
                        num_changed += examine_example(k);
            }

            if (examine_all == 1) examine_all = 0;
            else if (num_changed == 0) examine_all = 1;
            tot_changed += num_changed;
        }

        return true;
    }
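
    // test_learning() is a diagnostic helper: it prints, for every training sample,
    // the current decision value f(x_i), the cached error and the multiplier.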
    template<typename T>
    void SMO<T>::test_learning() {
        size_t i = 0, size = this->samples->size();
        for (i = 0; i < size; ++i)
            std::cout << i + 1 << " -> " << function(i) << " (error=" << this->l_data[i].error
                      << ") (alpha=" << (*this->samples)[i]->Alpha() << ")\n";
    }
    template<typename T>
    int SMO<T>::train_matrix(Kernel<T> *matrix) {
        size_t i = 0, size = this->samples->size();
        bool ret = true;
        double norm = 0.0;

        // Resize the shared bookkeeping used by training_routine().
        this->l_data.resize(size);

        this->solution.bias = 0;
        for (i = 0; i < size; i++)
            (*this->samples)[i]->Alpha() = 0;

        ret = training_routine();

        norm = matrix->featureSpaceNorm(this->samples);
        this->solution.margin = 1.0 / norm;

        this->solution.svs = 0;
        for (i = 0; i < size; ++i) {
            if ((*this->samples)[i]->Alpha() > 0) ++this->solution.svs;
            if ((*this->samples)[i]->Alpha() > this->C) ret = false;
        }

        int_dll::free(&this->head);

        return ret;
    }