Main Page   Modules   Data Structures   File List   Data Fields   Globals   Related Pages  

gnn_bfgs.c

Go to the documentation of this file.
00001 /***************************************************************************
00002  *  @file gnn_bfgs.c
00003  *  @brief gnn_bfgs Implementation.
00004  *
00005  *  @date   : 05-10-03 01:05
00006  *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
00007  *  Copyright  2003  Pedro Ortega C.
00008  ****************************************************************************/
00009 /*
00010  *  This program is free software; you can redistribute it and/or modify
00011  *  it under the terms of the GNU General Public License as published by
00012  *  the Free Software Foundation; either version 2 of the License, or
00013  *  (at your option) any later version.
00014  *
00015  *  This program is distributed in the hope that it will be useful,
00016  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018  *  GNU Library General Public License for more details.
00019  *
00020  *  You should have received a copy of the GNU General Public License
00021  *  along with this program; if not, write to the Free Software
00022  *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00023  */
00024 
00025 
00026 
00027 /**
00028  * @defgroup gnn_bfgs_doc gnn_bfgs : Broyden-Fletcher-Goldfarb-Shanno Algorithm.
00029  * @ingroup gnn_trainer_doc
00030  *
00031  * This trainer implements the Broyden-Fletcher-Goldfarb-Shanno (BFGS)
00032  * optimization algorithm. It is a so-called \em Quasi-Newton algorithm.
00033  * The parameters are updated using the formula:
00034  *
00035  * \f[ w^{(\tau + 1)} = w^{(\tau)} + \alpha^{(\tau)} G^{(\tau)} g^{(\tau)} \f]
00036  *
00037  * where \f$G^{(\tau)}\f$ is an approximation to the inverse Hessian matrix,
00038  * which is built step by step using the formula:
00039  *
00040  * \f[ G^{ (\tau+1) } =
00041  *         G^{ (\tau) }
00042  *         - \frac{ qq^T }{ v^T G^{(\tau)} v }
00043  *         + \frac{ pp^T }{ p^Tv }
00044  *         + ( v^T G^{(\tau)} v ) uu^T
00045  * \f]
00046  *
00047  * where the following vectors have been defined:
00048  *
00049  * \f[ p = w^{(\tau + 1)} - w^{(\tau)}                  \f]
00050  * \f[ q = G^{(\tau)} v                                 \f]
00051  * \f[ v = g^{(\tau + 1)} - g^{(\tau)}                  \f]
00052  * \f[ u = \frac{p}{p^T v} - \frac{q}{v^T G^{(\tau)} v} \f]
00053  *
00054  * and \f$g^{(\tau)}\f$ is the \f$\tau\f$-th gradient
00055  * \f$\frac{\partial E}{\partial w}\f$. The step size \f$\alpha^{(\tau)}\f$
00056  * is found by a line minimization, using one of the procedures available
00057  * in \ref gnn_line_search_doc.
00058  *
00059  * Although theoretically well founded and very robust, the BFGS algorithm
00060  * is expensive, both in time and space. The approximation of the inverse
00061  * Hessian matrix is kept in memory; it is of size \f$O(l^2)\f$, where
00062  * \f$l\f$ is the number of parameters, which leads to prohibitive storage
00063  * requirements for larger Gradient Machines.
00064  */
00065 
00066 
00067 
00068 /******************************************/
00069 /* Include Files                          */
00070 /******************************************/
00071 
00072 #include <math.h>
00073 #include <gsl/gsl_blas.h>
00074 #include "gnn_utilities.h"
00075 #include "gnn_bfgs.h"
00076 
00077 
00078 
00079 /******************************************/
00080 /* Static Declaration                     */
00081 /******************************************/
00082 
00083 static int
00084 gnn_bfgs_reset (gnn_trainer *trainer);
00085 
00086 static int
00087 gnn_bfgs_iteration (gnn_bfgs *bf);
00088 
00089 static int
00090 gnn_bfgs_train (gnn_trainer *trainer);
00091 
00092 static void
00093 gnn_bfgs_destroy (gnn_trainer *trainer);
00094 
00095 
00096 
00097 
00098 /******************************************/
00099 /* Static Implementation                  */
00100 /******************************************/
00101 
00102 /**
00103  * @brief The trainer's "reset" implementation.
00104  * @ingroup gnn_bfgs_doc
00105  *
00106  * @param  trainer A pointer to a \ref gnn_bfgs.
00107  * @return Returns 0 if suceeded.
00108  */
00109 static int
00110 gnn_bfgs_reset (gnn_trainer *trainer)
00111 {
00112     gnn_bfgs *bf;
00113 
00114     assert (trainer != NULL);
00115 
00116     bf = (gnn_bfgs *) trainer;
00117 
00118     /* reset iteration counter */
00119     bf->iteration = 0;
00120 
00121     /* clear optimization information */
00122     gsl_vector_set_zero (bf->wnew);
00123     gsl_vector_set_zero (bf->wold);
00124     gsl_vector_set_zero (bf->gnew);
00125     gsl_vector_set_zero (bf->gold);
00126     gsl_vector_set_zero (bf->u);
00127     gsl_vector_set_zero (bf->v);
00128     gsl_vector_set_zero (bf->p);
00129     gsl_vector_set_zero (bf->q);
00130 
00131     /* initialize Hessian approximation */
00132     gsl_matrix_set_identity (bf->G);
00133 
00134     return 0;
00135 }
00136 
00137 /**
00138  * @brief Computes the step-wise approximation of the Hessian matrix.
00139  * @ingroup gnn_bfgs_doc
00140  *
00141  * @param  trainer A pointer to a \ref gnn_bfgs.
00142  * @return Returns 0 if suceeded.
00143  */
00144 static int
00145 gnn_bfgs_iteration (gnn_bfgs *bf)
00146 {
00147     double pTv;
00148     double vTGv;
00149     double IpTv;
00150     double IvTGv;
00151 
00152     /* compute p vector */
00153     gsl_vector_memcpy (bf->p, bf->wnew);
00154     gsl_vector_sub (bf->p, bf->wold);
00155 
00156     /* compute v vector */
00157     gsl_vector_memcpy (bf->v, bf->gnew);
00158     gsl_vector_sub (bf->v, bf->gold);
00159 
00160     /* compute q vector */
00161     gsl_blas_dgemv (CblasNoTrans, 1.0, bf->G, bf->v, 0.0, bf->q);
00162 
00163     /* compute vTGv */
00164     gsl_blas_ddot (bf->v, bf->q, &vTGv);
00165 
00166     /* compute pTv */
00167     gsl_blas_ddot (bf->p, bf->v, &pTv);
00168     
00169     /* compute inverses */
00170     IpTv = 1.0 / GNN_MAX (pTv, GNN_TINY);
00171     IvTGv = 1.0 / GNN_MAX (vTGv, GNN_TINY);
00172 
00173     /* compute u vector */
00174     gsl_vector_memcpy (bf->u, bf->p);
00175     gsl_vector_scale (bf->u, 1.0 / pTv);
00176     gsl_blas_daxpy (- 1.0 / vTGv, bf->q, bf->u);
00177 
00178     /* compute Hessian approximation */
00179     gsl_blas_dgemm (CblasNoTrans, CblasTrans, -IvTGv, bf->Q, bf->Q, 1, bf->G);
00180     gsl_blas_dgemm (CblasNoTrans, CblasTrans,   IpTv, bf->P, bf->P, 1, bf->G);
00181     gsl_blas_dgemm (CblasNoTrans, CblasTrans,   vTGv, bf->U, bf->U, 1, bf->G);
00182 
00183     return 0;
00184 }
00185 
00186 /**
00187  * @brief The trainer's "train" implementation.
00188  * @ingroup gnn_bfgs_descent_doc
00189  *
00190  * @param  trainer A pointer to a \ref gnn_bfgs.
00191  * @return Returns 0 if succeeded.
00192  */
00193 static int
00194 gnn_bfgs_train (gnn_trainer *trainer)
00195 {
00196     double alpha;
00197 
00198     double ax, bx, cx;
00199     double fa, fb, fc;
00200 
00201     size_t s, n;
00202     gnn_bfgs *bf;
00203 
00204     /* get view */
00205     bf = (gnn_bfgs *) trainer;
00206 
00207     /* process minibatch */
00208     gnn_trainer_batch_process (trainer);
00209 
00210     /* copy gradient */
00211     gsl_vector_memcpy (bf->gnew, gnn_trainer_batch_get_dw (trainer));
00212 
00213     /* copy current parameter vector */
00214     gnn_node_param_get (trainer->node, bf->wnew);
00215 
00216     /* compute new search direction and perform one */
00217     /* step of the Hessian matrix approximation     */
00218     if (bf->iteration % bf->restart == 0)
00219     {
00220         gsl_matrix_set_identity (bf->G);
00221     }
00222     else
00223     {
00224         gnn_bfgs_iteration (bf);
00225     }
00226 
00227     /* prepare for next iteration */
00228     gsl_vector_memcpy (bf->wold, bf->wnew);
00229     gsl_vector_memcpy (bf->gold, bf->gnew);
00230 
00231     /* build new search direction */
00232     gsl_blas_dgemv (CblasNoTrans, 1.0, bf->G, bf->gnew, 0.0, bf->line->d);
00233 
00234     /* perform line search over the current minibatch's patterns */
00235     ax = 0.0;
00236     bx = bf->step;
00237     s = gnn_trainer_get_pattern_index (trainer);
00238     n = gnn_trainer_batch_get_size (trainer);
00239 
00240     gnn_line_search_bracket (bf->line, s, n, &ax, &bx, &cx, &fa, &fb, &fc);
00241     bf->alpha (bf->line, s, n, ax, bx, cx, &alpha, bf->tol);
00242 
00243     /* build new origin */
00244     gsl_vector_scale (bf->line->d, alpha);
00245     gsl_vector_add (bf->line->w, bf->line->d);
00246 
00247     /* update parameters */
00248     gnn_node_param_set (trainer->node, bf->line->w);
00249 
00250     /* move to next minibatch */
00251     gnn_trainer_batch_next (trainer);
00252 
00253     /* update iteration counter */
00254     bf->iteration++;
00255 
00256     return 0;
00257 }
00258 
00259 /**
00260  * @brief The trainers "destroy" implementation.
00261  * @ingroup gnn_bfgs_doc
00262  *
00263  * @param trainer A pointer to a \ref gnn_bfgs.
00264  */
00265 static void
00266 gnn_bfgs_destroy (gnn_trainer *trainer)
00267 {
00268     gnn_bfgs *bf;
00269 
00270     assert (trainer != NULL);
00271 
00272     bf = (gnn_bfgs *) trainer;
00273 
00274     if (bf->G != NULL)
00275         gsl_matrix_free (bf->G);
00276     if (bf->wnew != NULL)
00277         gsl_vector_free (bf->wnew);
00278     if (bf->wold != NULL)
00279         gsl_vector_free (bf->wold);
00280     if (bf->gnew != NULL)
00281         gsl_vector_free (bf->gnew);
00282     if (bf->gold != NULL)
00283         gsl_vector_free (bf->gold);
00284     if (bf->u != NULL)
00285         gsl_vector_free (bf->u);
00286     if (bf->v != NULL)
00287         gsl_vector_free (bf->v);
00288     if (bf->p != NULL)
00289         gsl_vector_free (bf->p);
00290     if (bf->q != NULL)
00291         gsl_vector_free (bf->q);
00292     if (bf->line != NULL)
00293         gnn_line_destroy (bf->line);
00294 
00295     return;
00296 }
00297 
00298 
00299 
00300 /******************************************/
00301 /* Public Interface                       */
00302 /******************************************/
00303 
00304 /**
00305  * @brief Creates a new BFGS trainer.
00306  * @ingroup gnn_bfgs_doc
00307  *
00308  * This function creates a new BFGS trainer (\ref gnn_bfgs).
00309  *
00310  * @param  node A pointer to a \ref gnn_node.
00311  * @param  crit A pointer to a \ref gnn_criterion.
00312  * @param  data A pointer to a \ref gnn_dataset.
00313  * @return Returns a pointer to a new \ref gnn_bfgs trainer.
00314  */
00315 gnn_trainer *
00316 gnn_bfgs_new (gnn_node *node, gnn_criterion *crit, gnn_dataset *data)
00317 {
00318     int status;
00319     size_t l;
00320     gnn_trainer *trainer;
00321     gnn_bfgs *bf;
00322 
00323     /* allocate memory for the trainer */
00324     bf = (gnn_bfgs *) malloc (sizeof (gnn_bfgs));
00325     if (bf == NULL)
00326     {
00327         GSL_ERROR_VAL ("couldn't allocate memory for gnn_bfgs",
00328                        GSL_ENOMEM, NULL);
00329     }
00330 
00331     /* get view as gnn_trainer */
00332     trainer = (gnn_trainer *) bf;
00333 
00334     /* initialize */
00335     status = gnn_trainer_init (trainer,
00336                                "gnn_bfgs",
00337                                node,
00338                                crit,
00339                                data,
00340                                gnn_bfgs_reset,
00341                                gnn_bfgs_train,
00342                                gnn_bfgs_destroy);
00343     if (status)
00344     {
00345         GSL_ERROR_VAL ("couldn't initialize gnn_bfgs",
00346                        GSL_EFAILED, NULL);
00347     }
00348 
00349     /* set fields */
00350     bf->step      = GNN_BFGS_STEP;
00351     bf->tol       = GNN_BFGS_TOL;
00352     bf->iteration = 0;
00353     bf->restart   = GNN_BFGS_RESTART;
00354 
00355     bf->alpha = GNN_BFGS_ALPHA;
00356 
00357     /* allocate memory for all needed buffers */
00358     l = gnn_node_param_get_size (node);
00359     bf->G    = gsl_matrix_alloc (l, l);
00360     bf->wnew = gsl_vector_alloc (l);
00361     bf->wold = gsl_vector_alloc (l);
00362     bf->gnew = gsl_vector_alloc (l);
00363     bf->gold = gsl_vector_alloc (l);
00364     bf->u    = gsl_vector_alloc (l);
00365     bf->v    = gsl_vector_alloc (l);
00366     bf->p    = gsl_vector_alloc (l);
00367     bf->q    = gsl_vector_alloc (l);
00368     bf->line = gnn_line_new (trainer->grad, NULL);
00369 
00370     if (   bf->G    == NULL
00371         || bf->wnew == NULL
00372         || bf->wold == NULL
00373         || bf->gnew == NULL
00374         || bf->gold == NULL
00375         || bf->u    == NULL
00376         || bf->v    == NULL
00377         || bf->q    == NULL
00378         || bf->p    == NULL
00379         || bf->line == NULL )
00380     {
00381         gnn_trainer_destroy (trainer);
00382         GSL_ERROR_VAL ("couldn't allocate memory for gnn_bfgs",
00383                        GSL_ENOMEM, NULL);
00384     }
00385 
00386     /* build needed matrix views of the vectors */
00387     bf->Uview = gsl_matrix_view_vector (bf->u, l, 1);
00388     bf->Vview = gsl_matrix_view_vector (bf->v, l, 1);
00389     bf->Pview = gsl_matrix_view_vector (bf->p, l, 1);
00390     bf->Qview = gsl_matrix_view_vector (bf->q, l, 1);
00391     bf->U     = &(bf->Uview.matrix);
00392     bf->V     = &(bf->Vview.matrix);
00393     bf->P     = &(bf->Pview.matrix);
00394     bf->Q     = &(bf->Qview.matrix);
00395 
00396     return trainer;
00397 }
00398 
00399 
00400 
00401 /**
00402  * @brief Sets the precision tolerance for the line search procedure.
00403  * @ingroup gnn_bfgs_doc
00404  *
00405  * @param trainer A pointer to a \ref gnn_bfgs.
00406  * @param tol A stricly positive real value.
00407  * @return Returns 0 if succeeded.
00408  */
00409 int
00410 gnn_bfgs_set_tol (gnn_trainer *trainer, double tol)
00411 {
00412     gnn_bfgs *bf;
00413 
00414     assert (trainer != NULL);
00415 
00416     /* check value */
00417     if (tol <= 0.0)
00418     {
00419         GSL_ERROR ("tolerance should be stricly greater than zero",
00420                    GSL_EINVAL);
00421     }
00422 
00423     /* set value */
00424     bf = (gnn_bfgs *) trainer;
00425     bf->tol = tol;
00426 
00427     return 0;
00428 }
00429 
00430 /**
00431  * @brief Gets the tolerance for the line search procedure.
00432  * @ingroup gnn_bfgs_doc
00433  *
00434  * @param trainer A pointer to a \ref gnn_bfgs.
00435  * @return Returns the tolerance's value.
00436  */
00437 double
00438 gnn_bfgs_get_tol (gnn_trainer *trainer)
00439 {
00440     gnn_bfgs *bf;
00441 
00442     assert (trainer != NULL);
00443 
00444     bf = (gnn_bfgs *) trainer;
00445 
00446     return bf->tol;
00447 }
00448 
00449 /**
00450  * @brief Sets the initial step for the interval bracketing procedure.
00451  * @ingroup gnn_bfgs_doc
00452  *
00453  * @param trainer A pointer to a \ref gnn_bfgs.
00454  * @param step A stricly positive real value.
00455  * @return Returns 0 if succeeded.
00456  */
00457 int
00458 gnn_bfgs_set_step (gnn_trainer *trainer, double step)
00459 {
00460     gnn_bfgs *bf;
00461 
00462     assert (trainer != NULL);
00463 
00464     /* check value */
00465     if (step <= 0.0)
00466     {
00467         GSL_ERROR ("step should be stricly greater than zero",
00468                    GSL_EINVAL);
00469     }
00470 
00471     /* set value */
00472     bf = (gnn_bfgs *) trainer;
00473     bf->step = step;
00474 
00475     return 0;
00476 }
00477 
00478 /**
00479  * @brief Gets the initial step for the interval bracketing procedure.
00480  * @ingroup gnn_bfgs_doc
00481  *
00482  * @param trainer A pointer to a \ref gnn_bfgs.
00483  * @return The trainer's internal step.
00484  */
00485 double
00486 gnn_bfgs_get_step (gnn_trainer *trainer)
00487 {
00488     gnn_bfgs *bf;
00489 
00490     assert (trainer != NULL);
00491 
00492     bf = (gnn_bfgs *) trainer;
00493 
00494     return bf->step;
00495 }
00496 
00497 /**
00498  * @brief Sets the number of iterations before restarting.
00499  * @ingroup gnn_bfgs_doc
00500  *
00501  * @param trainer A pointer to a \ref gnn_bfgs.
00502  * @param restart A stricly positive integer.
00503  * @return Returns 0 if succeeded.
00504  */
00505 int
00506 gnn_bfgs_set_restart (gnn_trainer *trainer, size_t restart)
00507 {
00508     gnn_bfgs *bf;
00509 
00510     assert (trainer != NULL);
00511 
00512     /* check value */
00513     if (restart <= 0.0)
00514     {
00515         GSL_ERROR ("restart iteration should be stricly greater than zero",
00516                    GSL_EINVAL);
00517     }
00518 
00519     /* set value */
00520     bf = (gnn_bfgs *) trainer;
00521     bf->restart = restart;
00522 
00523     return 0;
00524 }
00525 
00526 /**
00527  * @brief Gets the number of iterations before reinitializing the direction.
00528  * @ingroup gnn_bfgs_doc
00529  *
00530  * This function returns the number of iterations executed by the BFGS
00531  * trainer before reinitializing the search direction.
00532  *
00533  * @param trainer A pointer to a \ref gnn_bfgs.
00534  * @return Returns the number of iterations.
00535  */
00536 size_t
00537 gnn_bfgs_get_restart (gnn_trainer *trainer)
00538 {
00539     gnn_bfgs *bf;
00540 
00541     assert (trainer != NULL);
00542 
00543     bf = (gnn_bfgs *) trainer;
00544 
00545     return bf->restart;
00546 }
00547 
00548 /**
00549  * @brief Sets the line search procedure.
00550  * @ingroup gnn_bfgs_doc
00551  *
00552  * This function sets a new line search procedure used by the BFGS trainer.
00553  *
00554  * \code
00555  * gnn_trainer *trainer;
00556  * trainer = gnn_bfgs_new (node, crit, data);
00557  *
00558  * // use the Golden-Section line search
00559  * gnn_bfgs_set_line_search (trainer, gnn_line_search_golden);
00560  * \endcode
00561  *
00562  * Please refer to (\ref gnn_line_search_doc) for the available line search
00563  * procedures.
00564  *
00565  * @param trainer A pointer to a \ref gnn_bfgs.
00566  * @param lsearch A pointer to a line search procedure.
00567  * @return Returns 0 if succeeded.
00568  */
00569 int
00570 gnn_bfgs_set_line_search (gnn_trainer *trainer, gnn_line_search_type lsearch)
00571 {
00572     gnn_bfgs *bf;
00573 
00574     assert (trainer != NULL);
00575     assert (lsearch != NULL);
00576 
00577     /* set value */
00578     bf = (gnn_bfgs *) trainer;
00579     bf->alpha = lsearch;
00580 
00581     return 0;
00582 }
00583 
00584 /**
00585  * @brief Gets the installed line search procedure.
00586  * @ingroup gnn_bfgs_doc
00587  *
00588  * @param trainer A pointer to a \ref gnn_bfgs.
00589  * @return Returns a pointer to the installed line-search procedure.
00590  */
00591 gnn_line_search_type
00592 gnn_bfgs_get_alpha (gnn_trainer *trainer)
00593 {
00594     gnn_bfgs *bf;
00595 
00596     assert (trainer != NULL);
00597 
00598     bf = (gnn_bfgs *) trainer;
00599 
00600     return bf->alpha;
00601 }
00602 
00603 
00604 

Generated on Sun Jun 13 20:50:11 2004 for libgnn Gradient Retropropagation Machine Library by doxygen1.2.18