Main Page   Modules   Data Structures   File List   Data Fields   Globals   Related Pages  

gnn_momentum.c

Go to the documentation of this file.
00001 /***************************************************************************
00002  *  @file gnn_momentum.c
00003  *  @brief Gradient Descent with Momentum Trainer Implementation.
00004  *
00005  *  @date   : 31-08-03 23:48
00006  *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
00007  *  Copyright  2003  Pedro Ortega C.
00008  ****************************************************************************/
00009 /*
00010  *  This program is free software; you can redistribute it and/or modify
00011  *  it under the terms of the GNU General Public License as published by
00012  *  the Free Software Foundation; either version 2 of the License, or
00013  *  (at your option) any later version.
00014  *
00015  *  This program is distributed in the hope that it will be useful,
00016  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018  *  GNU Library General Public License for more details.
00019  *
00020  *  You should have received a copy of the GNU General Public License
00021  *  along with this program; if not, write to the Free Software
00022  *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00023  */
00024 
00025 
00026 
00027 /**
00028  * @defgroup gnn_momentum_doc gnn_momentum : Gradient Descent with Momentum Term.
00029  * @ingroup gnn_trainer_doc
00030  *
00031  * The present trainer provides an implementation of the gradient descent
00032  * with momentum term algorithm for parameter optimization. At each step,
00033  * the parameters are updated using the following rule:
00034  * \f[ \Delta w \leftarrow - \mu \frac{\partial E}{\partial w}
00035  *                         + \eta \Delta w
00036  * \f]
00037  * where \f$ \mu \f$ is the "learning rate", \f$ \eta \f$ is the
00038  * "momentum rate", and \f$\Delta w\f$ is the update (on the right-hand side, the previous one).
00039  *
00040  * The learning rate \f$\mu\f$ is typically one order of magnitude smaller than
00041  * the one used in the gradient descent algorithm. On the other hand, a typical
00042  * choice of \f$\eta\f$ is \f$ \eta = 0.8 \f$.
00043  */
00044 
00045 
00046 
00047 /******************************************/
00048 /* Include Files                          */
00049 /******************************************/
00050 
00051 #include "gnn_momentum.h"
00052 
00053 
00054 
00055 /******************************************/
00056 /* Static Declaration                     */
00057 /******************************************/
00058 
00059 typedef struct _gnn_momentum gnn_momentum;
00060 
00061 struct _gnn_momentum
00062 {
00063     gnn_trainer trainer;
00064     gsl_vector *w;
00065     gsl_vector *mw;
00066     double mu;
00067     double eta;
00068 };
00069 
00070 
00071 static int
00072 gnn_momentum_train (gnn_trainer   *trainer);
00073 
00074 static void
00075 gnn_momentum_destroy (gnn_trainer *trainer);
00076 
00077 
00078 
00079 /******************************************/
00080 /* Static Implementation                  */
00081 /******************************************/
00082 
00083 /**
00084  * @brief Reset function.
00085  * @ingroup gnn_momentum_doc
00086  *
00087  * @param  trainer A pointer to a \ref gnn_momentum.
00088  * @return Returns 0 if succeeded.
00089  */
00090 static int
00091 gnn_momentum_reset (gnn_trainer *trainer)
00092 {
00093     gnn_momentum *mtrainer;
00094 
00095     assert (trainer != NULL);
00096 
00097     mtrainer = (gnn_momentum *) trainer;
00098 
00099     /* get parameters */
00100     gnn_node_param_get (trainer->node, mtrainer->w);
00101 
00102     /* reset momenta */
00103     gsl_vector_set_zero (mtrainer->mw);
00104 
00105     return 0;
00106 }
00107 
00108 /**
00109  * @brief Train function.
00110  * @ingroup gnn_momentum_doc
00111  *
00112  * @param  trainer A pointer to a \ref gnn_momentum.
00113  * @return Returns 0 if succeeded.
00114  */
00115 static int
00116 gnn_momentum_train (gnn_trainer   *trainer)
00117 {
00118     gsl_vector *dw;
00119     gnn_momentum *mtrainer;
00120 
00121     /* get view */
00122     mtrainer = (gnn_momentum *) trainer;
00123 
00124     /* process minibatch */
00125     gnn_trainer_batch_process (trainer);
00126     
00127     /* get gradient */
00128     dw = gnn_trainer_batch_get_dw (trainer);
00129     
00130     /* move to next minibatch */
00131     gnn_trainer_batch_next (trainer);
00132 
00133     /* scale with learning factor */
00134     gsl_vector_scale (dw, - mtrainer->mu);
00135 
00136     /* modify momenta */
00137     gsl_vector_scale (mtrainer->mw, mtrainer->eta);
00138     gsl_vector_add (mtrainer->mw, dw);
00139 
00140     /* sum to parameters */
00141     gsl_vector_add (mtrainer->w, mtrainer->mw);
00142 
00143     /* update parameters */
00144     gnn_node_param_set (trainer->node, mtrainer->w);
00145 
00146     return 0;
00147 }
00148 
00149 /**
00150  * @brief Destructor.
00151  * @ingroup gnn_momentum_doc
00152  *
00153  * @param  trainer A pointer to a \ref gnn_momentum.
00154  */
00155 static void
00156 gnn_momentum_destroy (gnn_trainer *trainer)
00157 {
00158     gnn_momentum *mtrainer;
00159 
00160     assert (trainer != NULL);
00161 
00162     mtrainer = (gnn_momentum *) trainer;
00163 
00164     if (mtrainer->w != NULL)
00165         gsl_vector_free (mtrainer->w);
00166     if (mtrainer->mw != NULL)
00167         gsl_vector_free (mtrainer->mw);
00168 
00169     return;
00170 }
00171 
00172 
00173 
00174 /******************************************/
00175 /* Public Interface                       */
00176 /******************************************/
00177 
00178 /**
00179  * @brief Creates a new gradient descent with momentum trainer.
00180  * @ingroup gnn_momentum_doc
00181  *
00182  * This function creates a new gradient descent with momentum trainer
00183  * (\ref gnn_momentum). It uses the learning rate \f$\mu\f$ given by "mu"
00184  * and the momentum rate \f$\eta\f$ given by "eta", where \f$\mu > 0\f$ and
00185  * \f$\eta \geq 0\f$.
00186  *
00187  * @param  node A pointer to a \ref gnn_node.
00188  * @param  crit A pointer to a \ref gnn_criterion.
00189  * @param  data A pointer to a \ref gnn_dataset.
00190  * @param  mu   The learning rate \f$\mu\f$.
00191  * @param  eta  The momentum rate \f$\eta\f$.
00192  * @return Returns a pointer to a new \ref gnn_momentum trainer.
00193  */
00194 gnn_trainer *
00195 gnn_momentum_new (gnn_node *node,
00196                   gnn_criterion *crit,
00197                   gnn_dataset *data,
00198                   double mu,
00199                   double eta)
00200 {
00201     int status;
00202     gnn_trainer *trainer;
00203     gnn_momentum *mtrainer;
00204 
00205     /* check that mu isn't negative */
00206     if (mu <= 0.0)
00207     {
00208         GSL_ERROR_VAL ("learning factor should be stricly positive",
00209                        GSL_EINVAL, NULL);
00210     }
00211     /* check that eta isn't negative */
00212     if (eta < 0.0)
00213     {
00214         GSL_ERROR_VAL ("momentum factor should be positive",
00215                        GSL_EINVAL, NULL);
00216     }
00217 
00218     /* allocate memory for the trainer */
00219     mtrainer = (gnn_momentum *) malloc (sizeof (gnn_momentum));
00220     if (mtrainer == NULL)
00221     {
00222         GSL_ERROR_VAL ("couldn't allocate memory for gnn_momentum",
00223                        GSL_ENOMEM, NULL);
00224     }
00225 
00226     /* get view as gnn_trainer */
00227     trainer = (gnn_trainer *) mtrainer;
00228 
00229     /* initialize */
00230     status = gnn_trainer_init (trainer,
00231                                "gnn_momentum",
00232                                node,
00233                                crit,
00234                                data,
00235                                gnn_momentum_reset,
00236                                gnn_momentum_train,
00237                                gnn_momentum_destroy);
00238     if (status)
00239     {
00240         GSL_ERROR_VAL ("couldn't initialize gnn_momentum",
00241                        GSL_EFAILED, NULL);
00242     }
00243 
00244     /* set fields */
00245     mtrainer->mu  = mu;
00246     mtrainer->eta = eta;
00247     mtrainer->w   = gsl_vector_alloc (gnn_node_param_get_size (node));
00248     mtrainer->mw  = gsl_vector_alloc (gnn_node_param_get_size (node));
00249     if (mtrainer->w == NULL || mtrainer->mw == NULL)
00250     {
00251         gnn_trainer_destroy (trainer);
00252         GSL_ERROR_VAL ("couldn't allocate memory for gnn_momentum",
00253                        GSL_ENOMEM, NULL);
00254     }
00255 
00256     return trainer;
00257 }
00258 
00259 /**
00260  * @brief Gets the learning rate.
00261  * @ingroup gnn_momentum_doc
00262  *
00263  * This function returns the learning rate \f$\mu\f$ used by the gradient
00264  * descent with momentum trainer.
00265  *
00266  * @param  trainer A pointer to a \ref gnn_momentum.
00267  * @return Returns the learning rate \f$\mu\f$.
00268  */
00269 double
00270 gnn_momentum_get_mu (gnn_trainer *trainer)
00271 {
00272     gnn_momentum *mtrainer;
00273 
00274     assert (trainer != NULL);
00275 
00276     mtrainer = (gnn_momentum *) trainer;
00277     return mtrainer->mu;
00278 }
00279 
00280 /**
00281  * @brief Sets the learning rate.
00282  * @ingroup gnn_momentum_doc
00283  *
00284  * This function sets a new value for the learning rate \f$\mu\f$ used by
00285  * the gradient descent with momentum trainer. The learning rate should
00286  * be strictly positive.
00287  *
00288  * @param  trainer A pointer to a \ref gnn_momentum.
00289  * @param  mu      The learning rate \f$\mu\f$.
00290  * @return Returns 0 if suceeded.
00291  */
00292 int
00293 gnn_momentum_set_mu (gnn_trainer *trainer, double mu)
00294 {
00295     gnn_momentum *mtrainer;
00296 
00297     assert (trainer != NULL);
00298 
00299     /* check learning factor */
00300     if (mu <= 0.0)
00301         GSL_ERROR ("learning factor should be stricly positive", GSL_EINVAL);
00302 
00303     /* get view */
00304     mtrainer = (gnn_momentum *) trainer;
00305 
00306     /* set new learning factor */
00307     mtrainer->mu = mu;
00308 
00309     return 0;
00310 }
00311 
00312 /**
00313  * @brief Gets the momentum rate.
00314  * @ingroup gnn_momentum_doc
00315  *
00316  * This function returns the momentum rate \f$\eta\f$ used by the gradient
00317  * descent with momentum trainer.
00318  *
00319  * @param  trainer A pointer to a \ref gnn_momentum.
00320  * @return Returns the learning rate \f$\eta\f$.
00321  */
00322 double
00323 gnn_momentum_get_eta (gnn_trainer *trainer)
00324 {
00325     gnn_momentum *mtrainer;
00326 
00327     assert (trainer != NULL);
00328 
00329     mtrainer = (gnn_momentum *) trainer;
00330     return mtrainer->eta;
00331 }
00332 
00333 /**
00334  * @brief Sets the momentum rate.
00335  * @ingroup gnn_momentum_doc
00336  *
00337  * This function sets a new value for the momentum rate \f$\eta\f$ used by
00338  * the gradient descent with momentum trainer. The momentum rate should
00339  * be positive.
00340  *
00341  * @param  trainer A pointer to a \ref gnn_momentum.
00342  * @param  eta     The momentum rate \f$\eta\f$.
00343  * @return Returns 0 if suceeded.
00344  */
00345 int
00346 gnn_momentum_set_eta (gnn_trainer *trainer, double eta)
00347 {
00348     gnn_momentum *mtrainer;
00349 
00350     assert (trainer != NULL);
00351 
00352     /* check learning factor */
00353     if (eta < 0.0)
00354         GSL_ERROR ("momentum factor should be positive", GSL_EINVAL);
00355 
00356     /* get view */
00357     mtrainer = (gnn_momentum *) trainer;
00358 
00359     /* set new learning factor */
00360     mtrainer->eta = eta;
00361 
00362     return 0;
00363 }
00364 
00365 
00366 
00367 

Generated on Sun Jun 13 20:50:12 2004 for libgnn Gradient Retropropagation Machine Library by doxygen1.2.18