00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077 #include "gnn_utilities.h"
00078 #include "gnn_rprop.h"
00079
00080
00081
00082
00083
00084
00085
00086 typedef struct _gnn_rprop gnn_rprop;
00087
00088 struct _gnn_rprop
00089 {
00090 gnn_trainer trainer;
00091 gsl_vector *w;
00092 gsl_vector *dwnew;
00093 gsl_vector *dwold;
00094 gsl_vector *deltaw;
00095 gsl_vector *delta;
00096 double nup;
00097 double num;
00098 double deltamax;
00099 double deltamin;
00100 double delta0;
00101 };
00102
00103
00104 static int
00105 gnn_rprop_train (gnn_trainer *trainer);
00106
00107 static void
00108 gnn_rprop_destroy (gnn_trainer *trainer);
00109
00110
00111
00112
00113
00114
00115
00116
00117
00118
00119
00120
00121
00122
00123 static int
00124 gnn_rprop_reset (gnn_trainer *trainer)
00125 {
00126 gnn_rprop *rtrainer;
00127
00128 assert (trainer != NULL);
00129
00130 rtrainer = (gnn_rprop *) trainer;
00131
00132
00133 gnn_node_param_get (trainer->node, rtrainer->w);
00134
00135
00136 gsl_vector_set_all (rtrainer->delta, rtrainer->delta0);
00137 gsl_vector_set_zero (rtrainer->dwold);
00138
00139 return 0;
00140 }
00141
00142
00143
00144
00145
00146
00147
00148
00149 static int
00150 gnn_rprop_train (gnn_trainer *trainer)
00151 {
00152 size_t k;
00153 size_t param_size;
00154 gnn_rprop *rtrainer;
00155
00156
00157 rtrainer = (gnn_rprop *) trainer;
00158
00159
00160 gnn_trainer_batch_process (trainer);
00161
00162
00163 gsl_vector_memcpy (rtrainer->dwnew, gnn_trainer_batch_get_dw (trainer));
00164
00165
00166 gnn_trainer_batch_next (trainer);
00167
00168
00169 param_size = gnn_node_param_get_size (gnn_trainer_get_node (trainer));
00170
00171
00172 for (k=0; k<param_size; ++k)
00173 {
00174 double dwnewk;
00175 double dwoldk;
00176 double deltak;
00177 double deltawk;
00178 double wk;
00179
00180 dwnewk = gsl_vector_get (rtrainer->dwnew, k);
00181 dwoldk = gsl_vector_get (rtrainer->dwold, k);
00182 deltak = gsl_vector_get (rtrainer->delta, k);
00183 wk = gsl_vector_get (rtrainer->w, k);
00184
00185 if (dwnewk * dwoldk > 0)
00186 {
00187 deltak = GNN_MIN (rtrainer->nup * deltak, rtrainer->deltamax);
00188 deltawk = - GNN_SIGN (dwnewk) * deltak;
00189 wk = wk + deltawk;
00190 dwoldk = dwnewk;
00191 }
00192 else if (dwnewk * dwoldk < 0)
00193 {
00194 deltak = GNN_MAX (rtrainer->num * deltak, rtrainer->deltamin);
00195 dwoldk = 0;
00196 }
00197 else
00198 {
00199 deltawk = - GNN_SIGN (dwnewk) * deltak;
00200 wk = wk + deltawk;
00201 dwoldk = dwnewk;
00202 }
00203
00204 gsl_vector_set (rtrainer->dwold, k, dwoldk);
00205 gsl_vector_set (rtrainer->deltaw, k, deltawk);
00206 gsl_vector_set (rtrainer->delta, k, deltak);
00207 gsl_vector_set (rtrainer->w, k, wk);
00208 }
00209
00210
00211 gnn_node_param_set (trainer->node, rtrainer->w);
00212
00213 return 0;
00214 }
00215
00216
00217
00218
00219
00220
00221
00222 static void
00223 gnn_rprop_destroy (gnn_trainer *trainer)
00224 {
00225 gnn_rprop *rtrainer;
00226
00227 assert (trainer != NULL);
00228
00229 rtrainer = (gnn_rprop *) trainer;
00230
00231 if (rtrainer->w != NULL)
00232 gsl_vector_free (rtrainer->w);
00233 if (rtrainer->dwnew != NULL)
00234 gsl_vector_free (rtrainer->dwnew);
00235 if (rtrainer->dwold != NULL)
00236 gsl_vector_free (rtrainer->dwold);
00237 if (rtrainer->delta != NULL)
00238 gsl_vector_free (rtrainer->delta);
00239
00240 return;
00241 }
00242
00243
00244
00245
00246
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257
00258
00259
00260
00261
00262
00263
00264 gnn_trainer *
00265 gnn_rprop_new (gnn_node *node,
00266 gnn_criterion *crit,
00267 gnn_dataset *data,
00268 double delta0,
00269 double deltamax)
00270 {
00271 int status;
00272 size_t param_size;
00273 gnn_trainer *trainer;
00274 gnn_rprop *rtrainer;
00275
00276
00277 if (delta0 <= 0.0)
00278 {
00279 GSL_ERROR_VAL ("delta0 should be stricly positive",
00280 GSL_EINVAL, NULL);
00281 }
00282
00283 if (deltamax <= 0.0)
00284 {
00285 GSL_ERROR_VAL ("deltamax should be strictly positive",
00286 GSL_EINVAL, NULL);
00287 }
00288
00289
00290 rtrainer = (gnn_rprop *) malloc (sizeof (gnn_rprop));
00291 if (rtrainer == NULL)
00292 {
00293 GSL_ERROR_VAL ("couldn't allocate memory for gnn_rprop",
00294 GSL_ENOMEM, NULL);
00295 }
00296
00297
00298 trainer = (gnn_trainer *) rtrainer;
00299
00300
00301 status = gnn_trainer_init (trainer,
00302 "gnn_rprop",
00303 node,
00304 crit,
00305 data,
00306 gnn_rprop_reset,
00307 gnn_rprop_train,
00308 gnn_rprop_destroy);
00309 if (status)
00310 {
00311 GSL_ERROR_VAL ("couldn't initialize gnn_rprop",
00312 GSL_EFAILED, NULL);
00313 }
00314
00315
00316 param_size = gnn_node_param_get_size (node);
00317
00318 rtrainer->delta0 = delta0;
00319 rtrainer->deltamin = 0.000001;
00320 rtrainer->deltamax = deltamax;
00321 rtrainer->nup = 1.2;
00322 rtrainer->num = 0.5;
00323 rtrainer->w = gsl_vector_alloc (param_size);
00324 rtrainer->dwold = gsl_vector_alloc (param_size);
00325 rtrainer->dwnew = gsl_vector_alloc (param_size);
00326 rtrainer->delta = gsl_vector_alloc (param_size);
00327 rtrainer->deltaw = gsl_vector_alloc (param_size);
00328 if ( rtrainer->w == NULL
00329 || rtrainer->dwold == NULL
00330 || rtrainer->dwnew == NULL
00331 || rtrainer->delta == NULL
00332 || rtrainer->deltaw == NULL )
00333 {
00334 gnn_trainer_destroy (trainer);
00335 GSL_ERROR_VAL ("couldn't allocate memory for gnn_rprop",
00336 GSL_ENOMEM, NULL);
00337 }
00338
00339 return trainer;
00340 }
00341
00342
00343
00344
00345
00346
00347
00348
00349
00350
00351
00352
00353
00354
00355 gnn_trainer *
00356 gnn_rprop_standard_new (gnn_node *node,
00357 gnn_criterion *crit,
00358 gnn_dataset *data)
00359 {
00360 return gnn_rprop_new (node, crit, data, 0.1, 50.0);
00361 }
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377 int
00378 gnn_rprop_set_delta0 (gnn_trainer *trainer, double delta0)
00379 {
00380 gnn_rprop *rtrainer;
00381
00382 assert (trainer != NULL);
00383
00384
00385 if (delta0 <= 0.0)
00386 GSL_ERROR ("delta0 should be strictly positive", GSL_EINVAL);
00387
00388
00389 rtrainer = (gnn_rprop *) trainer;
00390
00391
00392 rtrainer->delta0 = delta0;
00393
00394 return 0;
00395 }
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406 double
00407 gnn_rprop_get_delta0 (gnn_trainer *trainer)
00408 {
00409 gnn_rprop *rtrainer;
00410
00411 assert (trainer != NULL);
00412
00413 rtrainer = (gnn_rprop *) trainer;
00414 return rtrainer->delta0;
00415 }
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429
00430 int
00431 gnn_rprop_set_deltamin (gnn_trainer *trainer, double deltamin)
00432 {
00433 gnn_rprop *rtrainer;
00434
00435 assert (trainer != NULL);
00436
00437
00438 if (deltamin <= 0.0)
00439 GSL_ERROR ("deltamin should be strictly positive", GSL_EINVAL);
00440
00441
00442 rtrainer = (gnn_rprop *) trainer;
00443
00444
00445 rtrainer->deltamin = deltamin;
00446
00447 return 0;
00448 }
00449
00450
00451
00452
00453
00454
00455
00456
00457
00458
00459 double
00460 gnn_rprop_get_deltamin (gnn_trainer *trainer)
00461 {
00462 gnn_rprop *rtrainer;
00463
00464 assert (trainer != NULL);
00465
00466 rtrainer = (gnn_rprop *) trainer;
00467 return rtrainer->deltamin;
00468 }
00469
00470
00471
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483 int
00484 gnn_rprop_set_deltamax (gnn_trainer *trainer, double deltamax)
00485 {
00486 gnn_rprop *rtrainer;
00487
00488 assert (trainer != NULL);
00489
00490
00491 if (deltamax <= 0.0)
00492 GSL_ERROR ("deltamax should be strictly positive", GSL_EINVAL);
00493
00494
00495 rtrainer = (gnn_rprop *) trainer;
00496
00497
00498 rtrainer->deltamax = deltamax;
00499
00500 return 0;
00501 }
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512 double
00513 gnn_rprop_get_deltamax (gnn_trainer *trainer)
00514 {
00515 gnn_rprop *rtrainer;
00516
00517 assert (trainer != NULL);
00518
00519 rtrainer = (gnn_rprop *) trainer;
00520 return rtrainer->deltamax;
00521 }
00522
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536 int
00537 gnn_rprop_set_nup (gnn_trainer *trainer, double nup)
00538 {
00539 gnn_rprop *rtrainer;
00540
00541 assert (trainer != NULL);
00542
00543
00544 if (nup <= 1.0)
00545 GSL_ERROR ("nup should be greater than 1", GSL_EINVAL);
00546
00547
00548 rtrainer = (gnn_rprop *) trainer;
00549
00550
00551 rtrainer->nup = nup;
00552
00553 return 0;
00554 }
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565 double
00566 gnn_rprop_get_nup (gnn_trainer *trainer)
00567 {
00568 gnn_rprop *rtrainer;
00569
00570 assert (trainer != NULL);
00571
00572 rtrainer = (gnn_rprop *) trainer;
00573 return rtrainer->nup;
00574 }
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589 int
00590 gnn_rprop_set_num (gnn_trainer *trainer, double num)
00591 {
00592 gnn_rprop *rtrainer;
00593
00594 assert (trainer != NULL);
00595
00596
00597 if (num <= 0.0 || num >= 1.0)
00598 GSL_ERROR ("num should be within (0,1)", GSL_EINVAL);
00599
00600
00601 rtrainer = (gnn_rprop *) trainer;
00602
00603
00604 rtrainer->num = num;
00605
00606 return 0;
00607 }
00608
00609
00610
00611
00612
00613
00614
00615
00616
00617
00618 double
00619 gnn_rprop_get_num (gnn_trainer *trainer)
00620 {
00621 gnn_rprop *rtrainer;
00622
00623 assert (trainer != NULL);
00624
00625 rtrainer = (gnn_rprop *) trainer;
00626 return rtrainer->num;
00627 }
00628