00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072 #include <math.h>
00073 #include <gsl/gsl_blas.h>
00074 #include "gnn_utilities.h"
00075 #include "gnn_bfgs.h"
00076
00077
00078
00079
00080
00081
00082
00083 static int
00084 gnn_bfgs_reset (gnn_trainer *trainer);
00085
00086 static int
00087 gnn_bfgs_iteration (gnn_bfgs *bf);
00088
00089 static int
00090 gnn_bfgs_train (gnn_trainer *trainer);
00091
00092 static void
00093 gnn_bfgs_destroy (gnn_trainer *trainer);
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109 static int
00110 gnn_bfgs_reset (gnn_trainer *trainer)
00111 {
00112 gnn_bfgs *bf;
00113
00114 assert (trainer != NULL);
00115
00116 bf = (gnn_bfgs *) trainer;
00117
00118
00119 bf->iteration = 0;
00120
00121
00122 gsl_vector_set_zero (bf->wnew);
00123 gsl_vector_set_zero (bf->wold);
00124 gsl_vector_set_zero (bf->gnew);
00125 gsl_vector_set_zero (bf->gold);
00126 gsl_vector_set_zero (bf->u);
00127 gsl_vector_set_zero (bf->v);
00128 gsl_vector_set_zero (bf->p);
00129 gsl_vector_set_zero (bf->q);
00130
00131
00132 gsl_matrix_set_identity (bf->G);
00133
00134 return 0;
00135 }
00136
00137
00138
00139
00140
00141
00142
00143
00144 static int
00145 gnn_bfgs_iteration (gnn_bfgs *bf)
00146 {
00147 double pTv;
00148 double vTGv;
00149 double IpTv;
00150 double IvTGv;
00151
00152
00153 gsl_vector_memcpy (bf->p, bf->wnew);
00154 gsl_vector_sub (bf->p, bf->wold);
00155
00156
00157 gsl_vector_memcpy (bf->v, bf->gnew);
00158 gsl_vector_sub (bf->v, bf->gold);
00159
00160
00161 gsl_blas_dgemv (CblasNoTrans, 1.0, bf->G, bf->v, 0.0, bf->q);
00162
00163
00164 gsl_blas_ddot (bf->v, bf->q, &vTGv);
00165
00166
00167 gsl_blas_ddot (bf->p, bf->v, &pTv);
00168
00169
00170 IpTv = 1.0 / GNN_MAX (pTv, GNN_TINY);
00171 IvTGv = 1.0 / GNN_MAX (vTGv, GNN_TINY);
00172
00173
00174 gsl_vector_memcpy (bf->u, bf->p);
00175 gsl_vector_scale (bf->u, 1.0 / pTv);
00176 gsl_blas_daxpy (- 1.0 / vTGv, bf->q, bf->u);
00177
00178
00179 gsl_blas_dgemm (CblasNoTrans, CblasTrans, -IvTGv, bf->Q, bf->Q, 1, bf->G);
00180 gsl_blas_dgemm (CblasNoTrans, CblasTrans, IpTv, bf->P, bf->P, 1, bf->G);
00181 gsl_blas_dgemm (CblasNoTrans, CblasTrans, vTGv, bf->U, bf->U, 1, bf->G);
00182
00183 return 0;
00184 }
00185
00186
00187
00188
00189
00190
00191
00192
00193 static int
00194 gnn_bfgs_train (gnn_trainer *trainer)
00195 {
00196 double alpha;
00197
00198 double ax, bx, cx;
00199 double fa, fb, fc;
00200
00201 size_t s, n;
00202 gnn_bfgs *bf;
00203
00204
00205 bf = (gnn_bfgs *) trainer;
00206
00207
00208 gnn_trainer_batch_process (trainer);
00209
00210
00211 gsl_vector_memcpy (bf->gnew, gnn_trainer_batch_get_dw (trainer));
00212
00213
00214 gnn_node_param_get (trainer->node, bf->wnew);
00215
00216
00217
00218 if (bf->iteration % bf->restart == 0)
00219 {
00220 gsl_matrix_set_identity (bf->G);
00221 }
00222 else
00223 {
00224 gnn_bfgs_iteration (bf);
00225 }
00226
00227
00228 gsl_vector_memcpy (bf->wold, bf->wnew);
00229 gsl_vector_memcpy (bf->gold, bf->gnew);
00230
00231
00232 gsl_blas_dgemv (CblasNoTrans, 1.0, bf->G, bf->gnew, 0.0, bf->line->d);
00233
00234
00235 ax = 0.0;
00236 bx = bf->step;
00237 s = gnn_trainer_get_pattern_index (trainer);
00238 n = gnn_trainer_batch_get_size (trainer);
00239
00240 gnn_line_search_bracket (bf->line, s, n, &ax, &bx, &cx, &fa, &fb, &fc);
00241 bf->alpha (bf->line, s, n, ax, bx, cx, &alpha, bf->tol);
00242
00243
00244 gsl_vector_scale (bf->line->d, alpha);
00245 gsl_vector_add (bf->line->w, bf->line->d);
00246
00247
00248 gnn_node_param_set (trainer->node, bf->line->w);
00249
00250
00251 gnn_trainer_batch_next (trainer);
00252
00253
00254 bf->iteration++;
00255
00256 return 0;
00257 }
00258
00259
00260
00261
00262
00263
00264
00265 static void
00266 gnn_bfgs_destroy (gnn_trainer *trainer)
00267 {
00268 gnn_bfgs *bf;
00269
00270 assert (trainer != NULL);
00271
00272 bf = (gnn_bfgs *) trainer;
00273
00274 if (bf->G != NULL)
00275 gsl_matrix_free (bf->G);
00276 if (bf->wnew != NULL)
00277 gsl_vector_free (bf->wnew);
00278 if (bf->wold != NULL)
00279 gsl_vector_free (bf->wold);
00280 if (bf->gnew != NULL)
00281 gsl_vector_free (bf->gnew);
00282 if (bf->gold != NULL)
00283 gsl_vector_free (bf->gold);
00284 if (bf->u != NULL)
00285 gsl_vector_free (bf->u);
00286 if (bf->v != NULL)
00287 gsl_vector_free (bf->v);
00288 if (bf->p != NULL)
00289 gsl_vector_free (bf->p);
00290 if (bf->q != NULL)
00291 gsl_vector_free (bf->q);
00292 if (bf->line != NULL)
00293 gnn_line_destroy (bf->line);
00294
00295 return;
00296 }
00297
00298
00299
00300
00301
00302
00303
00304
00305
00306
00307
00308
00309
00310
00311
00312
00313
00314
00315 gnn_trainer *
00316 gnn_bfgs_new (gnn_node *node, gnn_criterion *crit, gnn_dataset *data)
00317 {
00318 int status;
00319 size_t l;
00320 gnn_trainer *trainer;
00321 gnn_bfgs *bf;
00322
00323
00324 bf = (gnn_bfgs *) malloc (sizeof (gnn_bfgs));
00325 if (bf == NULL)
00326 {
00327 GSL_ERROR_VAL ("couldn't allocate memory for gnn_bfgs",
00328 GSL_ENOMEM, NULL);
00329 }
00330
00331
00332 trainer = (gnn_trainer *) bf;
00333
00334
00335 status = gnn_trainer_init (trainer,
00336 "gnn_bfgs",
00337 node,
00338 crit,
00339 data,
00340 gnn_bfgs_reset,
00341 gnn_bfgs_train,
00342 gnn_bfgs_destroy);
00343 if (status)
00344 {
00345 GSL_ERROR_VAL ("couldn't initialize gnn_bfgs",
00346 GSL_EFAILED, NULL);
00347 }
00348
00349
00350 bf->step = GNN_BFGS_STEP;
00351 bf->tol = GNN_BFGS_TOL;
00352 bf->iteration = 0;
00353 bf->restart = GNN_BFGS_RESTART;
00354
00355 bf->alpha = GNN_BFGS_ALPHA;
00356
00357
00358 l = gnn_node_param_get_size (node);
00359 bf->G = gsl_matrix_alloc (l, l);
00360 bf->wnew = gsl_vector_alloc (l);
00361 bf->wold = gsl_vector_alloc (l);
00362 bf->gnew = gsl_vector_alloc (l);
00363 bf->gold = gsl_vector_alloc (l);
00364 bf->u = gsl_vector_alloc (l);
00365 bf->v = gsl_vector_alloc (l);
00366 bf->p = gsl_vector_alloc (l);
00367 bf->q = gsl_vector_alloc (l);
00368 bf->line = gnn_line_new (trainer->grad, NULL);
00369
00370 if ( bf->G == NULL
00371 || bf->wnew == NULL
00372 || bf->wold == NULL
00373 || bf->gnew == NULL
00374 || bf->gold == NULL
00375 || bf->u == NULL
00376 || bf->v == NULL
00377 || bf->q == NULL
00378 || bf->p == NULL
00379 || bf->line == NULL )
00380 {
00381 gnn_trainer_destroy (trainer);
00382 GSL_ERROR_VAL ("couldn't allocate memory for gnn_bfgs",
00383 GSL_ENOMEM, NULL);
00384 }
00385
00386
00387 bf->Uview = gsl_matrix_view_vector (bf->u, l, 1);
00388 bf->Vview = gsl_matrix_view_vector (bf->v, l, 1);
00389 bf->Pview = gsl_matrix_view_vector (bf->p, l, 1);
00390 bf->Qview = gsl_matrix_view_vector (bf->q, l, 1);
00391 bf->U = &(bf->Uview.matrix);
00392 bf->V = &(bf->Vview.matrix);
00393 bf->P = &(bf->Pview.matrix);
00394 bf->Q = &(bf->Qview.matrix);
00395
00396 return trainer;
00397 }
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409 int
00410 gnn_bfgs_set_tol (gnn_trainer *trainer, double tol)
00411 {
00412 gnn_bfgs *bf;
00413
00414 assert (trainer != NULL);
00415
00416
00417 if (tol <= 0.0)
00418 {
00419 GSL_ERROR ("tolerance should be stricly greater than zero",
00420 GSL_EINVAL);
00421 }
00422
00423
00424 bf = (gnn_bfgs *) trainer;
00425 bf->tol = tol;
00426
00427 return 0;
00428 }
00429
00430
00431
00432
00433
00434
00435
00436
00437 double
00438 gnn_bfgs_get_tol (gnn_trainer *trainer)
00439 {
00440 gnn_bfgs *bf;
00441
00442 assert (trainer != NULL);
00443
00444 bf = (gnn_bfgs *) trainer;
00445
00446 return bf->tol;
00447 }
00448
00449
00450
00451
00452
00453
00454
00455
00456
00457 int
00458 gnn_bfgs_set_step (gnn_trainer *trainer, double step)
00459 {
00460 gnn_bfgs *bf;
00461
00462 assert (trainer != NULL);
00463
00464
00465 if (step <= 0.0)
00466 {
00467 GSL_ERROR ("step should be stricly greater than zero",
00468 GSL_EINVAL);
00469 }
00470
00471
00472 bf = (gnn_bfgs *) trainer;
00473 bf->step = step;
00474
00475 return 0;
00476 }
00477
00478
00479
00480
00481
00482
00483
00484
00485 double
00486 gnn_bfgs_get_step (gnn_trainer *trainer)
00487 {
00488 gnn_bfgs *bf;
00489
00490 assert (trainer != NULL);
00491
00492 bf = (gnn_bfgs *) trainer;
00493
00494 return bf->step;
00495 }
00496
00497
00498
00499
00500
00501
00502
00503
00504
00505 int
00506 gnn_bfgs_set_restart (gnn_trainer *trainer, size_t restart)
00507 {
00508 gnn_bfgs *bf;
00509
00510 assert (trainer != NULL);
00511
00512
00513 if (restart <= 0.0)
00514 {
00515 GSL_ERROR ("restart iteration should be stricly greater than zero",
00516 GSL_EINVAL);
00517 }
00518
00519
00520 bf = (gnn_bfgs *) trainer;
00521 bf->restart = restart;
00522
00523 return 0;
00524 }
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534
00535
00536 size_t
00537 gnn_bfgs_get_restart (gnn_trainer *trainer)
00538 {
00539 gnn_bfgs *bf;
00540
00541 assert (trainer != NULL);
00542
00543 bf = (gnn_bfgs *) trainer;
00544
00545 return bf->restart;
00546 }
00547
00548
00549
00550
00551
00552
00553
00554
00555
00556
00557
00558
00559
00560
00561
00562
00563
00564
00565
00566
00567
00568
00569 int
00570 gnn_bfgs_set_line_search (gnn_trainer *trainer, gnn_line_search_type lsearch)
00571 {
00572 gnn_bfgs *bf;
00573
00574 assert (trainer != NULL);
00575 assert (lsearch != NULL);
00576
00577
00578 bf = (gnn_bfgs *) trainer;
00579 bf->alpha = lsearch;
00580
00581 return 0;
00582 }
00583
00584
00585
00586
00587
00588
00589
00590
00591 gnn_line_search_type
00592 gnn_bfgs_get_alpha (gnn_trainer *trainer)
00593 {
00594 gnn_bfgs *bf;
00595
00596 assert (trainer != NULL);
00597
00598 bf = (gnn_bfgs *) trainer;
00599
00600 return bf->alpha;
00601 }
00602
00603
00604