00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include <math.h>
00068 #include "gnn_utilities.h"
00069 #include "gnn_weight_elimination.h"
00070
00071
00072
00073
00074
00075
00076
00077 typedef struct _gnn_weight_elimination gnn_weight_elimination;
00078
00079 struct _gnn_weight_elimination
00080 {
00081 gnn_criterion crit;
00082 gnn_criterion *subcrit;
00083 gnn_pbundle *pb;
00084 gsl_vector *w;
00085 gsl_vector *dw;
00086 double nu;
00087 double wp;
00088 };
00089
00090
00091 double
00092 gnn_weight_elimination_e (gnn_criterion *crit,
00093 const gsl_vector *y,
00094 const gsl_vector *t);
00095
00096 int
00097 gnn_weight_elimination_dy (gnn_criterion *crit,
00098 const gsl_vector *y,
00099 const gsl_vector *t,
00100 gsl_vector * dy);
00101
00102 static void
00103 gnn_weight_elimination_destroy (gnn_criterion *crit);
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122 double
00123 gnn_weight_elimination_e (gnn_criterion *crit,
00124 const gsl_vector *y,
00125 const gsl_vector *t)
00126 {
00127 size_t i;
00128 double E;
00129 double Ep;
00130 double wsum;
00131 gnn_weight_elimination *we;
00132
00133 assert (crit != NULL);
00134 assert (y != NULL);
00135 assert (t != NULL);
00136
00137
00138 we = (gnn_weight_elimination *) crit;
00139
00140
00141 if (y->size != t->size)
00142 GSL_ERROR_VAL ("vector sizes should be the same", GSL_EINVAL, 0.0);
00143
00144
00145 E = gnn_criterion_evaluate_e (we->subcrit, y, t);
00146
00147
00148 gnn_pbundle_get_w (we->pb, we->w);
00149
00150
00151 Ep = 0.0;
00152 for (i=0; i<we->w->size; ++i)
00153 {
00154 double wi;
00155 double wpwi;
00156
00157 wi = gsl_vector_get (we->w, i);
00158 wpwi = we->wp * we->wp + wi * wi;
00159 wpwi = GNN_MAX (GNN_TINY, wpwi);
00160 Ep += wi * wi / wpwi;
00161 }
00162
00163 return E + 0.5 * we->nu * Ep;
00164 }
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178 int
00179 gnn_weight_elimination_dy (gnn_criterion *crit,
00180 const gsl_vector *y,
00181 const gsl_vector *t,
00182 gsl_vector * dy)
00183 {
00184 int i;
00185 gnn_weight_elimination *we;
00186
00187 assert (crit != NULL);
00188 assert (y != NULL);
00189 assert (t != NULL);
00190 assert (dy != NULL);
00191
00192
00193 we = (gnn_weight_elimination *) crit;
00194
00195
00196 gnn_criterion_evaluate_dy (we->subcrit, dy);
00197
00198
00199 gnn_pbundle_get_dw (we->pb, we->dw);
00200
00201
00202 for (i=0; i<we->w->size; ++i)
00203 {
00204 double wi;
00205 double dwi;
00206 double wpwi2;
00207
00208 wi = gsl_vector_get (we->w, i);
00209 dwi = gsl_vector_get (we->dw, i);
00210
00211 wpwi2 = we->wp * we->wp + wi * wi;
00212 wpwi2 = GNN_MAX (GNN_TINY, wpwi2 * wpwi2);
00213
00214 dwi += we->nu * 2 * wi * (we->wp * we->wp / wpwi2);
00215
00216 gsl_vector_set (we->dw, i, dwi);
00217 }
00218
00219
00220 gnn_pbundle_set_dw (we->pb, we->dw);
00221
00222 return 0;
00223 }
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233 static void
00234 gnn_weight_elimination_destroy (gnn_criterion *crit)
00235 {
00236 gnn_weight_elimination *we;
00237
00238 assert (crit != NULL);
00239
00240
00241 we = (gnn_weight_elimination *) crit;
00242
00243
00244 if (we->subcrit != NULL)
00245 gnn_criterion_destroy (we->subcrit);
00246 if (we->pb != NULL)
00247 gnn_pbundle_destroy (we->pb);
00248 if (we->w != NULL)
00249 gsl_vector_free (we->w);
00250 if (we->dw != NULL)
00251 gsl_vector_free (we->dw);
00252 }
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268
00269
00270
00271
00272
00273
00274 gnn_criterion *
00275 gnn_weight_elimination_new (gnn_criterion *crit,
00276 double nu,
00277 double wp,
00278 gnn_node *node)
00279 {
00280 gnn_pbundle *pb;
00281 gnn_weight_elimination *we;
00282
00283 assert (crit != NULL);
00284 assert (node != NULL);
00285
00286
00287 pb = gnn_node_sub_search_params (node, "gnn_weight");
00288
00289
00290 we = (gnn_weight_elimination *)
00291 gnn_weight_elimination_new_with_pbundle (crit, nu, wp, pb);
00292 if (we == NULL)
00293 {
00294 gnn_pbundle_destroy (pb);
00295 GSL_ERROR_VAL ("couldn't create gnn_weight_elimination regularizer",
00296 GSL_EFAILED, NULL);
00297 }
00298
00299 return (gnn_criterion *) we;
00300 }
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315
00316
00317 gnn_criterion *
00318 gnn_weight_elimination_new_with_type (gnn_criterion *crit,
00319 double nu,
00320 double wp,
00321 gnn_node *node,
00322 const char *type)
00323 {
00324 gnn_pbundle *pb;
00325 gnn_weight_elimination *we;
00326
00327 assert (crit != NULL);
00328 assert (node != NULL);
00329 assert (type != NULL);
00330
00331
00332
00333 pb = gnn_node_sub_search_params (node, type);
00334
00335
00336 we = (gnn_weight_elimination *)
00337 gnn_weight_elimination_new_with_pbundle (crit, nu, wp, pb);
00338 if (we == NULL)
00339 {
00340 gnn_pbundle_destroy (pb);
00341 GSL_ERROR_VAL ("couldn't create gnn_weight_elimination regularizer",
00342 GSL_EFAILED, NULL);
00343 }
00344
00345 return (gnn_criterion *) we;
00346 }
00347
00348
00349
00350
00351
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362 gnn_criterion *
00363 gnn_weight_elimination_new_with_pbundle (gnn_criterion *crit,
00364 double nu,
00365 double wp,
00366 gnn_pbundle *pb)
00367 {
00368 int status;
00369 size_t l;
00370 size_t size;
00371 gnn_criterion *c;
00372 gnn_weight_elimination *we;
00373
00374 assert (crit != NULL);
00375 assert (pb != NULL);
00376
00377
00378 if (nu < 0.0)
00379 {
00380 GSL_ERROR_VAL ("penalty term coefficient should be positive",
00381 GSL_EINVAL, NULL);
00382 }
00383
00384
00385 if (wp <= 0.0)
00386 {
00387 GSL_ERROR_VAL ("weight scale factor should be strictly positive",
00388 GSL_EINVAL, NULL);
00389 }
00390
00391
00392 we = (gnn_weight_elimination *) malloc (sizeof (gnn_weight_elimination));
00393 if (we == NULL)
00394 {
00395 GSL_ERROR_VAL ("couldn't alloc memory for gnn_weight_elimination",
00396 GSL_ENOMEM, NULL);
00397 }
00398
00399
00400 c = (gnn_criterion *) we;
00401
00402
00403 size = gnn_criterion_get_size (crit);
00404
00405
00406 status = gnn_criterion_init (c,
00407 "gnn_weight_elimination",
00408 size,
00409 gnn_weight_elimination_e,
00410 gnn_weight_elimination_dy,
00411 gnn_weight_elimination_destroy);
00412 if (status)
00413 {
00414 gnn_criterion_destroy (c);
00415 GSL_ERROR_VAL ("couldn't initialize gnn_weight_elimination",
00416 GSL_EFAILED, NULL);
00417 }
00418
00419
00420 l = gnn_pbundle_get_size (pb);
00421
00422
00423 we->subcrit = crit;
00424 we->pb = pb;
00425 we->w = gsl_vector_alloc (l);
00426 we->dw = gsl_vector_alloc (l);
00427 we->nu = nu;
00428 we->wp = wp;
00429
00430 if (we->dw == NULL || we->w == NULL)
00431 {
00432 gnn_criterion_destroy (c);
00433 GSL_ERROR_VAL ("couldn't allocate memory for internal "
00434 "gnn_weight_elimination buffer", GSL_EFAILED, NULL);
00435 }
00436
00437 return c;
00438 }
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449
00450
00451 int
00452 gnn_weight_elimination_set_nu (gnn_criterion *crit, double nu)
00453 {
00454 gnn_weight_elimination *we;
00455
00456 assert (crit != NULL);
00457
00458 if (nu < 0.0)
00459 {
00460 GSL_ERROR ("penalty term coefficient should be positive", GSL_EINVAL);
00461 }
00462
00463 we = (gnn_weight_elimination *) crit;
00464 we->nu = nu;
00465
00466 return 0;
00467 }
00468
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479 double
00480 gnn_weight_elimination_get_nu (gnn_criterion *crit)
00481 {
00482 gnn_weight_elimination *we;
00483
00484 assert (crit != NULL);
00485
00486 we = (gnn_weight_elimination *) crit;
00487 return we->nu;
00488 }
00489
00490
00491
00492
00493
00494
00495
00496
00497
00498
00499
00500
00501 int
00502 gnn_weight_elimination_set_wp (gnn_criterion *crit, double wp)
00503 {
00504 gnn_weight_elimination *we;
00505
00506 assert (crit != NULL);
00507
00508 if (wp <= 0.0)
00509 {
00510 GSL_ERROR ("weight scale coefficient should be strictly positive",
00511 GSL_EINVAL);
00512 }
00513
00514 we = (gnn_weight_elimination *) crit;
00515 we->wp = wp;
00516
00517 return 0;
00518 }
00519
00520
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530 double
00531 gnn_weight_elimination_get_wp (gnn_criterion *crit)
00532 {
00533 gnn_weight_elimination *we;
00534
00535 assert (crit != NULL);
00536
00537 we = (gnn_weight_elimination *) crit;
00538 return we->wp;
00539 }
00540
00541
00542
00543
00544
00545
00546
00547
00548
00549
00550
00551
00552 int
00553 gnn_weight_elimination_prun (gnn_criterion *crit, double th)
00554 {
00555 size_t i;
00556 size_t size;
00557 gnn_weight_elimination *we;
00558
00559 assert (crit != NULL);
00560
00561
00562 we = (gnn_weight_elimination *) crit;
00563
00564
00565 size = gnn_pbundle_get_size (we->pb);
00566
00567
00568
00569 for (i=0; i<size; ++i)
00570 {
00571 double wi;
00572
00573 wi = gnn_pbundle_get_w_at (we->pb, i);
00574 if (fabs (wi) < th)
00575 {
00576 gnn_pbundle_set_w_at (we->pb, i, 0.0);
00577 gnn_pbundle_set_f_at (we->pb, i, 1);
00578 }
00579 }
00580
00581 return 0;
00582 }
00583
00584
00585