00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
#include <assert.h>
#include <math.h>
#include <stdlib.h>

#include "gnn_conjugate_gradient.h"
00072
00073
00074
00075
00076
00077
00078
00079 static int
00080 gnn_conjugate_gradient_reset (gnn_trainer *trainer);
00081
00082 static int
00083 gnn_conjugate_gradient_train (gnn_trainer *trainer);
00084
00085 static void
00086 gnn_conjugate_gradient_destroy (gnn_trainer *trainer);
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102 static int
00103 gnn_conjugate_gradient_reset (gnn_trainer *trainer)
00104 {
00105 gnn_conjugate_gradient *cgtrainer;
00106
00107 assert (trainer != NULL);
00108
00109 cgtrainer = (gnn_conjugate_gradient *) trainer;
00110
00111
00112 cgtrainer->iteration = 0;
00113
00114
00115 gsl_vector_set_zero (cgtrainer->buf);
00116 gsl_vector_set_zero (cgtrainer->gnew);
00117 gsl_vector_set_zero (cgtrainer->gold);
00118
00119 return 0;
00120 }
00121
00122
00123
00124
00125
00126
00127
00128
00129 static int
00130 gnn_conjugate_gradient_train (gnn_trainer *trainer)
00131 {
00132 double alpha;
00133 double beta;
00134 double ax, bx, cx;
00135 double fa, fb, fc;
00136 size_t s, n;
00137 gnn_conjugate_gradient *cg;
00138
00139
00140 cg = (gnn_conjugate_gradient *) trainer;
00141
00142
00143 gnn_trainer_batch_process (trainer);
00144
00145
00146 gsl_vector_memcpy (cg->gnew, gnn_trainer_batch_get_dw (trainer));
00147
00148
00149 if (cg->iteration % cg->restart == 0)
00150 {
00151
00152 gsl_vector_memcpy (cg->line->d, cg->gnew);
00153 }
00154 else
00155 {
00156
00157 beta = cg->beta (trainer);
00158 if (beta < 0.0)
00159 {
00160
00161 gsl_vector_memcpy (cg->line->d, cg->gnew);
00162 }
00163 else
00164 {
00165
00166 gsl_vector_scale (cg->line->d, beta);
00167 gsl_vector_sub (cg->line->d, cg->gnew);
00168 }
00169 }
00170
00171
00172 ax = 0.0;
00173 bx = cg->step;
00174 s = gnn_trainer_get_pattern_index (trainer);
00175 n = gnn_trainer_batch_get_size (trainer);
00176
00177 gnn_line_search_bracket (cg->line, s, n, &ax, &bx, &cx, &fa, &fb, &fc);
00178 cg->alpha (cg->line, s, n, ax, bx, cx, &alpha, cg->tol);
00179
00180
00181 gsl_vector_memcpy (cg->buf, cg->line->d);
00182 gsl_vector_scale (cg->buf, alpha);
00183
00184
00185 gsl_vector_add (cg->line->w, cg->buf);
00186
00187
00188 gnn_node_param_set (trainer->node, cg->line->w);
00189
00190
00191 gsl_vector_memcpy (cg->gold, cg->gnew);
00192
00193
00194 gnn_trainer_batch_next (trainer);
00195
00196
00197 cg->iteration++;
00198
00199 return 0;
00200 }
00201
00202
00203
00204
00205
00206
00207
00208 static void
00209 gnn_conjugate_gradient_destroy (gnn_trainer *trainer)
00210 {
00211 gnn_conjugate_gradient *cgtrainer;
00212
00213 assert (trainer != NULL);
00214
00215 cgtrainer = (gnn_conjugate_gradient *) trainer;
00216
00217 if (cgtrainer->buf != NULL)
00218 gsl_vector_free (cgtrainer->buf);
00219
00220 if (cgtrainer->gnew != NULL)
00221 gsl_vector_free (cgtrainer->gnew);
00222
00223 if (cgtrainer->gold != NULL)
00224 gsl_vector_free (cgtrainer->gold);
00225
00226 if (cgtrainer->line != NULL)
00227 gnn_line_destroy (cgtrainer->line);
00228
00229 return;
00230 }
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250 gnn_trainer *
00251 gnn_conjugate_gradient_new (gnn_node *node,
00252 gnn_criterion *crit,
00253 gnn_dataset *data)
00254 {
00255 int status;
00256 size_t l;
00257 gnn_trainer *trainer;
00258 gnn_conjugate_gradient *cgtrainer;
00259
00260
00261 cgtrainer =
00262 (gnn_conjugate_gradient *) malloc (sizeof (gnn_conjugate_gradient));
00263 if (cgtrainer == NULL)
00264 {
00265 GSL_ERROR_VAL ("couldn't allocate memory for gnn_conjugate_gradient",
00266 GSL_ENOMEM, NULL);
00267 }
00268
00269
00270 trainer = (gnn_trainer *) cgtrainer;
00271
00272
00273 status = gnn_trainer_init (trainer,
00274 "gnn_conjugate_gradient",
00275 node,
00276 crit,
00277 data,
00278 gnn_conjugate_gradient_reset,
00279 gnn_conjugate_gradient_train,
00280 gnn_conjugate_gradient_destroy);
00281 if (status)
00282 {
00283 GSL_ERROR_VAL ("couldn't initialize gnn_conjugate_gradient",
00284 GSL_EFAILED, NULL);
00285 }
00286
00287
00288 cgtrainer->step = GNN_CONJUGATE_GRADIENT_STEP;
00289 cgtrainer->tol = GNN_CONJUGATE_GRADIENT_TOL;
00290 cgtrainer->iteration = 0;
00291 cgtrainer->restart = GNN_CONJUGATE_GRADIENT_RESTART;
00292
00293 cgtrainer->alpha = GNN_CONJUGATE_GRADIENT_ALPHA;
00294 cgtrainer->beta = GNN_CONJUGATE_GRADIENT_BETA;
00295
00296 l = gnn_node_param_get_size (node);
00297 cgtrainer->gnew = gsl_vector_alloc (l);
00298 cgtrainer->gold = gsl_vector_alloc (l);
00299 cgtrainer->buf = gsl_vector_alloc (l);
00300 cgtrainer->line = gnn_line_new (trainer->grad, NULL);
00301
00302 if ( cgtrainer->buf == NULL
00303 || cgtrainer->gnew == NULL
00304 || cgtrainer->gold == NULL )
00305 {
00306 gnn_trainer_destroy (trainer);
00307 GSL_ERROR_VAL ("couldn't allocate memory for gnn_conjugate_gradient",
00308 GSL_ENOMEM, NULL);
00309 }
00310
00311 return trainer;
00312 }
00313
00314
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334 double
00335 gnn_conjugate_gradient_polak_ribiere (gnn_trainer *trainer)
00336 {
00337 double beta_u;
00338 double beta_l;
00339 gnn_conjugate_gradient *cgtrainer;
00340
00341
00342 cgtrainer = (gnn_conjugate_gradient *) trainer;
00343
00344
00345 gsl_vector_memcpy (cgtrainer->buf, cgtrainer->gnew);
00346 gsl_vector_sub (cgtrainer->buf, cgtrainer->gold);
00347 gsl_blas_ddot (cgtrainer->gnew, cgtrainer->buf, &beta_u);
00348 gsl_blas_ddot (cgtrainer->gold, cgtrainer->gold, &beta_l);
00349
00350 return ( beta_u / beta_l );
00351 }
00352
00353
00354
00355
00356
00357
00358
00359
00360
00361
00362
00363
00364
00365
00366
00367
00368
00369
00370
00371 double
00372 gnn_conjugate_gradient_hestenes_stiefel (gnn_trainer *trainer)
00373 {
00374 double beta_u;
00375 double beta_l;
00376 gnn_conjugate_gradient *cgtrainer;
00377
00378
00379 cgtrainer = (gnn_conjugate_gradient *) trainer;
00380
00381
00382 gsl_vector_memcpy (cgtrainer->buf, cgtrainer->gnew);
00383 gsl_vector_sub (cgtrainer->buf, cgtrainer->gold);
00384
00385 gsl_blas_ddot (cgtrainer->gnew, cgtrainer->buf, &beta_u);
00386 gsl_blas_ddot (cgtrainer->line->d, cgtrainer->buf, &beta_l);
00387
00388 return ( beta_u / beta_l );
00389 }
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409 double
00410 gnn_conjugate_gradient_fletcher_reeves (gnn_trainer *trainer)
00411 {
00412 double beta_u;
00413 double beta_l;
00414 gnn_conjugate_gradient *cgtrainer;
00415
00416
00417 cgtrainer = (gnn_conjugate_gradient *) trainer;
00418
00419
00420 gsl_blas_ddot (cgtrainer->gnew, cgtrainer->gnew, &beta_u);
00421 gsl_blas_ddot (cgtrainer->gold, cgtrainer->gold, &beta_l);
00422
00423 return ( beta_u / beta_l );
00424 }
00425
00426
00427
00428
00429
00430
00431
00432
00433
00434 int
00435 gnn_conjugate_gradient_set_tol (gnn_trainer *trainer, double tol)
00436 {
00437 gnn_conjugate_gradient *cg;
00438
00439 assert (trainer != NULL);
00440
00441
00442 if (tol <= 0.0)
00443 {
00444 GSL_ERROR ("tolerance should be stricly greater than zero",
00445 GSL_EINVAL);
00446 }
00447
00448
00449 cg = (gnn_conjugate_gradient *) trainer;
00450 cg->tol = tol;
00451
00452 return 0;
00453 }
00454
00455
00456
00457
00458
00459
00460
00461
00462 double
00463 gnn_conjugate_gradient_get_tol (gnn_trainer *trainer)
00464 {
00465 gnn_conjugate_gradient *cg;
00466
00467 assert (trainer != NULL);
00468
00469 cg = (gnn_conjugate_gradient *) trainer;
00470
00471 return cg->tol;
00472 }
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482 int
00483 gnn_conjugate_gradient_set_step (gnn_trainer *trainer, double step)
00484 {
00485 gnn_conjugate_gradient *cg;
00486
00487 assert (trainer != NULL);
00488
00489
00490 if (step <= 0.0)
00491 {
00492 GSL_ERROR ("step should be stricly greater than zero",
00493 GSL_EINVAL);
00494 }
00495
00496
00497 cg = (gnn_conjugate_gradient *) trainer;
00498 cg->step = step;
00499
00500 return 0;
00501 }
00502
00503
00504
00505
00506
00507
00508
00509
00510 double
00511 gnn_conjugate_gradient_get_step (gnn_trainer *trainer)
00512 {
00513 gnn_conjugate_gradient *cg;
00514
00515 assert (trainer != NULL);
00516
00517 cg = (gnn_conjugate_gradient *) trainer;
00518
00519 return cg->step;
00520 }
00521
00522
00523
00524
00525
00526
00527
00528
00529
00530 int
00531 gnn_conjugate_gradient_set_restart (gnn_trainer *trainer, size_t restart)
00532 {
00533 gnn_conjugate_gradient *cg;
00534
00535 assert (trainer != NULL);
00536
00537
00538 if (restart <= 0.0)
00539 {
00540 GSL_ERROR ("restart iteration should be stricly greater than zero",
00541 GSL_EINVAL);
00542 }
00543
00544
00545 cg = (gnn_conjugate_gradient *) trainer;
00546 cg->restart = restart;
00547
00548 return 0;
00549 }
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561 size_t
00562 gnn_conjugate_gradient_get_restart (gnn_trainer *trainer)
00563 {
00564 gnn_conjugate_gradient *cg;
00565
00566 assert (trainer != NULL);
00567
00568 cg = (gnn_conjugate_gradient *) trainer;
00569
00570 return cg->restart;
00571 }
00572
00573
00574
00575
00576
00577
00578
00579
00580
00581
00582
00583
00584
00585
00586
00587
00588
00589
00590
00591
00592
00593
00594
00595 int
00596 gnn_conjugate_gradient_set_line_search (gnn_trainer *trainer,
00597 gnn_line_search_type lsearch)
00598 {
00599 gnn_conjugate_gradient *cg;
00600
00601 assert (trainer != NULL);
00602 assert (lsearch != NULL);
00603
00604
00605 cg = (gnn_conjugate_gradient *) trainer;
00606 cg->alpha = lsearch;
00607
00608 return 0;
00609 }
00610
00611
00612
00613
00614
00615
00616
00617
00618 gnn_line_search_type
00619 gnn_conjugate_gradient_get_alpha (gnn_trainer *trainer)
00620 {
00621 gnn_conjugate_gradient *cg;
00622
00623 assert (trainer != NULL);
00624
00625 cg = (gnn_conjugate_gradient *) trainer;
00626
00627 return cg->alpha;
00628 }
00629
00630
00631
00632
00633
00634
00635
00636
00637
00638
00639
00640
00641
00642
00643
00644
00645
00646
00647
00648
00649
00650 int
00651 gnn_conjugate_gradient_set_beta (gnn_trainer *trainer,
00652 gnn_conjugate_gradient_beta beta)
00653 {
00654 gnn_conjugate_gradient *cg;
00655
00656 assert (trainer != NULL);
00657 assert (beta != NULL);
00658
00659
00660 cg = (gnn_conjugate_gradient *) trainer;
00661 cg->beta = beta;
00662
00663 return 0;
00664 }
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674
00675
00676 gnn_conjugate_gradient_beta
00677 gnn_conjugate_gradient_get_beta (gnn_trainer *trainer)
00678 {
00679 gnn_conjugate_gradient *cg;
00680
00681 assert (trainer != NULL);
00682
00683 cg = (gnn_conjugate_gradient *) trainer;
00684
00685 return cg->beta;
00686 }
00687