00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_blas.h>
#include "gnn_utilities.h"
#include "gnn_lmbfgs.h"
00076
00077
00078
00079
00080
00081
00082
00083 static int
00084 gnn_lmbfgs_reset (gnn_trainer *trainer);
00085
00086 static int
00087 gnn_lmbfgs_iteration (gnn_lmbfgs *bf, double *A, double *B);
00088
00089 static int
00090 gnn_lmbfgs_train (gnn_trainer *trainer);
00091
00092 static void
00093 gnn_lmbfgs_destroy (gnn_trainer *trainer);
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109 static int
00110 gnn_lmbfgs_reset (gnn_trainer *trainer)
00111 {
00112 gnn_lmbfgs *bf;
00113
00114 assert (trainer != NULL);
00115
00116 bf = (gnn_lmbfgs *) trainer;
00117
00118
00119 bf->iteration = 0;
00120
00121
00122 gsl_vector_set_zero (bf->wnew);
00123 gsl_vector_set_zero (bf->wold);
00124 gsl_vector_set_zero (bf->gnew);
00125 gsl_vector_set_zero (bf->gold);
00126 gsl_vector_set_zero (bf->v);
00127 gsl_vector_set_zero (bf->p);
00128
00129 return 0;
00130 }
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141 static int
00142 gnn_lmbfgs_iteration (gnn_lmbfgs *bf, double *A, double *B)
00143 {
00144 double pTv;
00145 double pTg;
00146 double vTg;
00147 double vTv;
00148
00149
00150 gsl_vector_memcpy (bf->p, bf->wnew);
00151 gsl_vector_sub (bf->p, bf->wold);
00152
00153
00154 gsl_vector_memcpy (bf->v, bf->gnew);
00155 gsl_vector_sub (bf->v, bf->gold);
00156
00157
00158 gsl_blas_ddot (bf->p, bf->v, &pTv);
00159 pTv = GNN_MAX (pTv, GNN_TINY);
00160
00161
00162 gsl_blas_ddot (bf->v, bf->v, &vTv);
00163
00164
00165 gsl_blas_ddot (bf->p, bf->gnew, &vTv);
00166
00167
00168 gsl_blas_ddot (bf->v, bf->gnew, &vTv);
00169
00170
00171 *A = -(1.0 + vTv / pTv) * pTg / pTv + vTg / pTv;
00172 *B = pTg / pTv;
00173
00174 return 0;
00175 }
00176
00177
00178
00179
00180
00181
00182
00183
00184 static int
00185 gnn_lmbfgs_train (gnn_trainer *trainer)
00186 {
00187 double A;
00188 double B;
00189 double alpha;
00190
00191 double ax, bx, cx;
00192 double fa, fb, fc;
00193
00194 size_t s, n;
00195 gnn_lmbfgs *bf;
00196
00197
00198 bf = (gnn_lmbfgs *) trainer;
00199
00200
00201 gnn_trainer_batch_process (trainer);
00202
00203
00204 gsl_vector_memcpy (bf->gnew, gnn_trainer_batch_get_dw (trainer));
00205
00206
00207 gnn_node_param_get (trainer->node, bf->wnew);
00208
00209
00210 if (bf->iteration % bf->restart == 0)
00211 {
00212
00213 A = 0.0;
00214 B = 0.0;
00215 }
00216 else
00217 {
00218
00219 gnn_lmbfgs_iteration (bf, &A, &B);
00220 }
00221
00222
00223 gsl_vector_memcpy (bf->wold, bf->wnew);
00224 gsl_vector_memcpy (bf->gold, bf->gnew);
00225
00226
00227 gsl_vector_memcpy (bf->line->d, bf->gnew);
00228 gsl_vector_scale (bf->line->d, -1.0);
00229 gsl_blas_daxpy (A, bf->p, bf->line->d);
00230 gsl_blas_daxpy (B, bf->v, bf->line->d);
00231
00232
00233 ax = 0.0;
00234 bx = bf->step;
00235 s = gnn_trainer_get_pattern_index (trainer);
00236 n = gnn_trainer_batch_get_size (trainer);
00237
00238 gnn_line_search_bracket (bf->line, s, n, &ax, &bx, &cx, &fa, &fb, &fc);
00239 bf->alpha (bf->line, s, n, ax, bx, cx, &alpha, bf->tol);
00240
00241
00242 gsl_vector_scale (bf->line->d, alpha);
00243 gsl_vector_add (bf->line->w, bf->line->d);
00244
00245
00246 gnn_node_param_set (trainer->node, bf->line->w);
00247
00248
00249 gnn_trainer_batch_next (trainer);
00250
00251
00252 bf->iteration++;
00253
00254 return 0;
00255 }
00256
00257
00258
00259
00260
00261
00262
00263 static void
00264 gnn_lmbfgs_destroy (gnn_trainer *trainer)
00265 {
00266 gnn_lmbfgs *bf;
00267
00268 assert (trainer != NULL);
00269
00270 bf = (gnn_lmbfgs *) trainer;
00271
00272 if (bf->wnew != NULL)
00273 gsl_vector_free (bf->wnew);
00274 if (bf->wold != NULL)
00275 gsl_vector_free (bf->wold);
00276 if (bf->gnew != NULL)
00277 gsl_vector_free (bf->gnew);
00278 if (bf->gold != NULL)
00279 gsl_vector_free (bf->gold);
00280 if (bf->v != NULL)
00281 gsl_vector_free (bf->v);
00282 if (bf->p != NULL)
00283 gsl_vector_free (bf->p);
00284 if (bf->line != NULL)
00285 gnn_line_destroy (bf->line);
00286
00287 return;
00288 }
00289
00290
00291
00292
00293
00294
00295
00296
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307 gnn_trainer *
00308 gnn_lmbfgs_new (gnn_node *node, gnn_criterion *crit, gnn_dataset *data)
00309 {
00310 int status;
00311 size_t l;
00312 gnn_trainer *trainer;
00313 gnn_lmbfgs *bf;
00314
00315
00316 bf = (gnn_lmbfgs *) malloc (sizeof (gnn_lmbfgs));
00317 if (bf == NULL)
00318 {
00319 GSL_ERROR_VAL ("couldn't allocate memory for gnn_lmbfgs",
00320 GSL_ENOMEM, NULL);
00321 }
00322
00323
00324 trainer = (gnn_trainer *) bf;
00325
00326
00327 status = gnn_trainer_init (trainer,
00328 "gnn_lmbfgs",
00329 node,
00330 crit,
00331 data,
00332 gnn_lmbfgs_reset,
00333 gnn_lmbfgs_train,
00334 gnn_lmbfgs_destroy);
00335 if (status)
00336 {
00337 GSL_ERROR_VAL ("couldn't initialize gnn_lmbfgs",
00338 GSL_EFAILED, NULL);
00339 }
00340
00341
00342 bf->step = GNN_LMBFGS_STEP;
00343 bf->tol = GNN_LMBFGS_TOL;
00344 bf->iteration = 0;
00345 bf->restart = GNN_LMBFGS_RESTART;
00346
00347 bf->alpha = GNN_LMBFGS_ALPHA;
00348
00349
00350 l = gnn_node_param_get_size (node);
00351 bf->wnew = gsl_vector_alloc (l);
00352 bf->wold = gsl_vector_alloc (l);
00353 bf->gnew = gsl_vector_alloc (l);
00354 bf->gold = gsl_vector_alloc (l);
00355 bf->v = gsl_vector_alloc (l);
00356 bf->p = gsl_vector_alloc (l);
00357 bf->line = gnn_line_new (trainer->grad, NULL);
00358
00359 if ( bf->wnew == NULL
00360 || bf->wold == NULL
00361 || bf->gnew == NULL
00362 || bf->gold == NULL
00363 || bf->v == NULL
00364 || bf->p == NULL
00365 || bf->line == NULL )
00366 {
00367 gnn_trainer_destroy (trainer);
00368 GSL_ERROR_VAL ("couldn't allocate memory for gnn_lmbfgs",
00369 GSL_ENOMEM, NULL);
00370 }
00371
00372 return trainer;
00373 }
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385 int
00386 gnn_lmbfgs_set_tol (gnn_trainer *trainer, double tol)
00387 {
00388 gnn_lmbfgs *bf;
00389
00390 assert (trainer != NULL);
00391
00392
00393 if (tol <= 0.0)
00394 {
00395 GSL_ERROR ("tolerance should be stricly greater than zero",
00396 GSL_EINVAL);
00397 }
00398
00399
00400 bf = (gnn_lmbfgs *) trainer;
00401 bf->tol = tol;
00402
00403 return 0;
00404 }
00405
00406
00407
00408
00409
00410
00411
00412
00413 double
00414 gnn_lmbfgs_get_tol (gnn_trainer *trainer)
00415 {
00416 gnn_lmbfgs *bf;
00417
00418 assert (trainer != NULL);
00419
00420 bf = (gnn_lmbfgs *) trainer;
00421
00422 return bf->tol;
00423 }
00424
00425
00426
00427
00428
00429
00430
00431
00432
00433 int
00434 gnn_lmbfgs_set_step (gnn_trainer *trainer, double step)
00435 {
00436 gnn_lmbfgs *bf;
00437
00438 assert (trainer != NULL);
00439
00440
00441 if (step <= 0.0)
00442 {
00443 GSL_ERROR ("step should be stricly greater than zero",
00444 GSL_EINVAL);
00445 }
00446
00447
00448 bf = (gnn_lmbfgs *) trainer;
00449 bf->step = step;
00450
00451 return 0;
00452 }
00453
00454
00455
00456
00457
00458
00459
00460
00461 double
00462 gnn_lmbfgs_get_step (gnn_trainer *trainer)
00463 {
00464 gnn_lmbfgs *bf;
00465
00466 assert (trainer != NULL);
00467
00468 bf = (gnn_lmbfgs *) trainer;
00469
00470 return bf->step;
00471 }
00472
00473
00474
00475
00476
00477
00478
00479
00480
00481 int
00482 gnn_lmbfgs_set_restart (gnn_trainer *trainer, size_t restart)
00483 {
00484 gnn_lmbfgs *bf;
00485
00486 assert (trainer != NULL);
00487
00488
00489 if (restart <= 0.0)
00490 {
00491 GSL_ERROR ("restart iteration should be stricly greater than zero",
00492 GSL_EINVAL);
00493 }
00494
00495
00496 bf = (gnn_lmbfgs *) trainer;
00497 bf->restart = restart;
00498
00499 return 0;
00500 }
00501
00502
00503
00504
00505
00506
00507
00508
00509
00510
00511
00512 size_t
00513 gnn_lmbfgs_get_restart (gnn_trainer *trainer)
00514 {
00515 gnn_lmbfgs *bf;
00516
00517 assert (trainer != NULL);
00518
00519 bf = (gnn_lmbfgs *) trainer;
00520
00521 return bf->restart;
00522 }
00523
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536
00537
00538
00539
00540
00541
00542
00543
00544
00545 int
00546 gnn_lmbfgs_set_line_search (gnn_trainer *trainer, gnn_line_search_type lsearch)
00547 {
00548 gnn_lmbfgs *bf;
00549
00550 assert (trainer != NULL);
00551 assert (lsearch != NULL);
00552
00553
00554 bf = (gnn_lmbfgs *) trainer;
00555 bf->alpha = lsearch;
00556
00557 return 0;
00558 }
00559
00560
00561
00562
00563
00564
00565
00566
00567 gnn_line_search_type
00568 gnn_lmbfgs_get_alpha (gnn_trainer *trainer)
00569 {
00570 gnn_lmbfgs *bf;
00571
00572 assert (trainer != NULL);
00573
00574 bf = (gnn_lmbfgs *) trainer;
00575
00576 return bf->alpha;
00577 }
00578
00579
00580