Main Page   Modules   Data Structures   File List   Data Fields   Globals   Related Pages  

gnn_weight_elimination.c

Go to the documentation of this file.
00001 /***************************************************************************
00002  *  @file gnn_weight_elimination.c
00003  *  @brief gnn_weight_elimination Implementation.
00004  *
00005  *  @date   : 05-10-03 19:23
00006  *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
00007  *  Copyright  2003  Pedro Ortega C.
00008  ****************************************************************************/
00009 /*
00010  *  This program is free software; you can redistribute it and/or modify
00011  *  it under the terms of the GNU General Public License as published by
00012  *  the Free Software Foundation; either version 2 of the License, or
00013  *  (at your option) any later version.
00014  *
00015  *  This program is distributed in the hope that it will be useful,
00016  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018  *  GNU Library General Public License for more details.
00019  *
00020  *  You should have received a copy of the GNU General Public License
00021  *  along with this program; if not, write to the Free Software
00022  *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00023  */
00024 
00025 /**
00026  * @brief Weight Elimination Regularization.
00027  * @defgroup gnn_weight_elimination_doc gnn_weight_elimination : Weight Elimination Regularization.
00028  * @ingroup gnn_criterion_doc
00029  *
00030  * This datatype implements the \em Weight \em Elimination Regularization,
00031  * given by the error penalty term:
00032  *
00033  * \f[ \Omega = \sum_i \frac{ w_i^2 }{ \hat{w}^2 + w_i^2 } \f]
00034  *
00035  * Where the \f$w_i, i=0,\ldots,L\f$ are parameters that have been chosen
00036  * from a \ref gnn_node by hand. They are usually those of a \ref gnn_weight
00037  * linear transform node. The term \f$\hat{w}\f$ is a scale parameter that
00038  * usually is taken of order unity. The \em weight \em elimination regularizer
 * will favour a few large weights over many small ones, whereas
 * \em weight \em decay (\ref gnn_weight_decay) favours many small weights.
00041  *
00042  * Although the weights never fall to zero in practice, after training the
00043  * associated \ref gnn_node a while, you can call the
00044  * \ref gnn_weight_elimination_prun function, which deletes the weights
00045  * below a given threshold.
00046  *
00047  * As in general with all regularizers, the \ref gnn_weight_elimination doesn't
00048  * work by itself, and needs a \ref gnn_criterion to build upon.
00049  *
00050  * The resulting error function is:
00051  *
00052  * \f[ \widetilde{E} = E + \nu \Omega \f]
00053  *
 * where \f$\nu\f$ is the regularization coefficient. The corresponding
 * gradient with respect to each penalized parameter \f$w_i\f$ is:
 * \f[ \frac{\partial \widetilde{E}}{\partial w_i} =
 *         \frac{\partial E}{\partial w_i}
 *         + 2 \nu w_i \frac{ \hat{w}^2 }{ (\hat{w}^2 + w_i^2)^2 }
 * \f]
00059  */
00060 
00061 
00062 
00063 /******************************************/
00064 /* Include Files                          */
00065 /******************************************/
00066 
00067 #include <math.h>
00068 #include "gnn_utilities.h"
00069 #include "gnn_weight_elimination.h"
00070 
00071 
00072 
00073 /******************************************/
00074 /* Static Declaration                     */
00075 /******************************************/
00076 
00077 typedef struct _gnn_weight_elimination gnn_weight_elimination;
00078 
00079 struct _gnn_weight_elimination
00080 {
00081     gnn_criterion  crit;
00082     gnn_criterion *subcrit;
00083     gnn_pbundle   *pb;
00084     gsl_vector    *w;
00085     gsl_vector    *dw;
00086     double         nu;
00087     double         wp;
00088 };
00089 
00090 
00091 double
00092 gnn_weight_elimination_e (gnn_criterion *crit,
00093                           const gsl_vector *y,
00094                           const gsl_vector *t);
00095 
00096 int
00097 gnn_weight_elimination_dy (gnn_criterion *crit,
00098                            const gsl_vector *y,
00099                            const gsl_vector *t,
00100                            gsl_vector * dy);
00101 
00102 static void
00103 gnn_weight_elimination_destroy (gnn_criterion *crit);
00104 
00105 
00106 
00107 /******************************************/
00108 /* Static Implementation                  */
00109 /******************************************/
00110 
00111 /**
00112  * @brief The evaluation function.
00113  * @ingroup gnn_weight_elimination_doc
00114  *
00115  * This function corresponds to the evaluation function.
00116  *
00117  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00118  * @param  y    A pointer to an estimation vector \f$y\f$.
00119  * @param  t    A pointer to the desired target vector \f$t\f$.
00120  * @return A real number corresponding to the value of the criterion.
00121  */
00122 double
00123 gnn_weight_elimination_e (gnn_criterion *crit,
00124                           const gsl_vector *y,
00125                           const gsl_vector *t)
00126 {
00127     size_t i;
00128     double E;
00129     double Ep;
00130     double wsum;
00131     gnn_weight_elimination *we;
00132 
00133     assert (crit != NULL);
00134     assert (y != NULL);
00135     assert (t != NULL);
00136 
00137     /* get view as a gnn_weight_elimination */
00138     we = (gnn_weight_elimination *) crit;
00139 
00140     /* check sizes */
00141     if (y->size != t->size)
00142         GSL_ERROR_VAL ("vector sizes should be the same", GSL_EINVAL, 0.0);
00143 
00144     /* compute subnode error */
00145     E = gnn_criterion_evaluate_e (we->subcrit, y, t);
00146 
00147     /* get parameter vector */
00148     gnn_pbundle_get_w (we->pb, we->w);
00149 
00150     /* compute penalization term */
00151     Ep = 0.0;
00152     for (i=0; i<we->w->size; ++i)
00153     {
00154         double wi;
00155         double wpwi;
00156 
00157         wi   = gsl_vector_get (we->w, i);
00158         wpwi = we->wp * we->wp + wi * wi;
00159         wpwi = GNN_MAX (GNN_TINY, wpwi);
00160         Ep += wi * wi / wpwi;
00161     }
00162 
00163     return E + 0.5 * we->nu * Ep;
00164 }
00165 
00166 /**
00167  * @brief The gradient evaluation function.
00168  * @ingroup gnn_weight_elimination_doc
00169  *
00170  * This function implements criterion's gradient evaluation function.
00171  *
00172  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00173  * @param  y    A pointer to an estimation vector \f$y\f$.
00174  * @param  t    A pointer to the desired target vector \f$t\f$.
00175  * @param  dy   A pointer to a buffer vector where the result should be placed.
00176  * @return Returns 0 if succeeded.
00177  */
00178 int
00179 gnn_weight_elimination_dy (gnn_criterion *crit,
00180                            const gsl_vector *y,
00181                            const gsl_vector *t,
00182                            gsl_vector * dy)
00183 {
00184     int i;
00185     gnn_weight_elimination *we;
00186 
00187     assert (crit != NULL);
00188     assert (y    != NULL);
00189     assert (t    != NULL);
00190     assert (dy   != NULL);
00191 
00192     /* get view as a gnn_weight_elimination */
00193     we = (gnn_weight_elimination *) crit;
00194 
00195     /* compute subnode dy */
00196     gnn_criterion_evaluate_dy (we->subcrit, dy);
00197 
00198     /* get parameter gradient vector */
00199     gnn_pbundle_get_dw (we->pb, we->dw);
00200 
00201     /* compute gradient penalization terms */
00202     for (i=0; i<we->w->size; ++i)
00203     {
00204         double wi;
00205         double dwi;
00206         double wpwi2;
00207 
00208         wi  = gsl_vector_get (we->w, i);
00209         dwi = gsl_vector_get (we->dw, i);
00210 
00211         wpwi2 = we->wp * we->wp + wi * wi;
00212         wpwi2 = GNN_MAX (GNN_TINY, wpwi2 * wpwi2);
00213 
00214         dwi += we->nu * 2 * wi * (we->wp * we->wp / wpwi2);
00215 
00216         gsl_vector_set (we->dw, i, dwi);
00217     }
00218 
00219     /* set parameter gradient vector penalization */
00220     gnn_pbundle_set_dw (we->pb, we->dw);
00221 
00222     return 0;
00223 }
00224 
00225 /**
00226  * @brief The destroy function.
00227  * @ingroup gnn_weight_elimination_doc
00228  *
00229  * This function implements destroy function.
00230  *
00231  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00232  */
00233 static void
00234 gnn_weight_elimination_destroy (gnn_criterion *crit)
00235 {
00236     gnn_weight_elimination *we;
00237 
00238     assert (crit != NULL);
00239 
00240     /* get view as a gnn_weight_elimination */
00241     we = (gnn_weight_elimination *) crit;
00242 
00243     /* free allocate memory */
00244     if (we->subcrit != NULL)
00245         gnn_criterion_destroy (we->subcrit);
00246     if (we->pb != NULL)
00247         gnn_pbundle_destroy (we->pb);
00248     if (we->w != NULL)
00249         gsl_vector_free (we->w);
00250     if (we->dw != NULL)
00251         gsl_vector_free (we->dw);
00252 }
00253 
00254 
00255 
00256 /******************************************/
00257 /* Public Interface                       */
00258 /******************************************/
00259 
00260 /**
00261  * @brief Creates a Weight Elimination regularization for \ref gnn_weight.
00262  * @ingroup gnn_weight_elimination_doc
00263  *
00264  * This function builds a new \ref gnn_weight_elimination regularizer for
00265  * all \ref gnn_weight nodes contained in the given node \a node.
00266  *
00267  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00268  * @param  nu   The regularization coefficient \f$\nu\f$.
00269  * @param  wp   The scale factor \f$\hat{w}\f$.
00270  * @param  node A pointer to a \ref gnn_node.
00271  * @return Returns a pointer to a \ref gnn_weight_elimination
00272  *         if succeeded or NULL.
00273  */
00274 gnn_criterion *
00275 gnn_weight_elimination_new (gnn_criterion *crit,
00276                             double         nu,
00277                             double         wp,
00278                             gnn_node *node)
00279 {
00280     gnn_pbundle *pb;
00281     gnn_weight_elimination *we;
00282 
00283     assert (crit != NULL);
00284     assert (node != NULL);
00285 
00286     /* get parameter bundle consisting of gnn_weight nodes' params */
00287     pb = gnn_node_sub_search_params (node, "gnn_weight");
00288 
00289     /* create weight elimination penalizer */
00290     we = (gnn_weight_elimination *)
00291           gnn_weight_elimination_new_with_pbundle (crit, nu, wp, pb);
00292     if (we == NULL)
00293     {
00294         gnn_pbundle_destroy (pb);
00295         GSL_ERROR_VAL ("couldn't create gnn_weight_elimination regularizer",
00296                        GSL_EFAILED, NULL);
00297     }
00298 
00299     return (gnn_criterion *) we;
00300 }
00301 
00302 /**
00303  * @brief Creates a Weight Elimination regularization for a given type.
00304  * @ingroup gnn_weight_elimination_doc
00305  *
00306  * This function builds a new \ref gnn_weight_elimination regularizer for
00307  * all parameters that pertain to nodes of type \a type contained in the
00308  * given node \a node.
00309  *
00310  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00311  * @param  nu   The regularization coefficient \f$\nu\f$.
00312  * @param  wp   The scale parameter \f$\hat{w}\f$.
00313  * @param  node A pointer to a \ref gnn_node.
00314  * @param  type A string containing the type.
00315  * @return Returns a pointer to a \ref gnn_weight_elimination if succeeded or NULL.
00316  */
00317 gnn_criterion *
00318 gnn_weight_elimination_new_with_type (gnn_criterion *crit,
00319                                       double         nu,
00320                                       double         wp,
00321                                       gnn_node      *node,
00322                                       const char    *type)
00323 {
00324     gnn_pbundle *pb;
00325     gnn_weight_elimination *we;
00326 
00327     assert (crit != NULL);
00328     assert (node != NULL);
00329     assert (type != NULL);
00330 
00331     /* get parameter bundle consisting of  */
00332     /* the nodes' params of the given type */
00333     pb = gnn_node_sub_search_params (node, type);
00334 
00335     /* create weight decay penalizer */
00336     we = (gnn_weight_elimination *)
00337           gnn_weight_elimination_new_with_pbundle (crit, nu, wp, pb);
00338     if (we == NULL)
00339     {
00340         gnn_pbundle_destroy (pb);
00341         GSL_ERROR_VAL ("couldn't create gnn_weight_elimination regularizer",
00342                        GSL_EFAILED, NULL);
00343     }
00344 
00345     return (gnn_criterion *) we;
00346 }
00347 
00348 /**
00349  * @brief Creates a Weight Elimination regularization for a given \em pbundle.
00350  * @ingroup gnn_weight_elimination_doc
00351  *
00352  * This function builds a new \ref gnn_weight_elimination regularizer for
00353  * all parameters contained in the given parameter bundle \a pb.
00354  *
00355  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00356  * @param  nu   The regularization coefficient \f$\nu\f$.
00357  * @param  wp   The scale parameter \f$\hat{w}\f$.
00358  * @param  pb   A \ref gnn_pbundle.
00359  * @return Returns a pointer to a \ref gnn_weight_elimination
00360  *         if succeeded or NULL.
00361  */
00362 gnn_criterion *
00363 gnn_weight_elimination_new_with_pbundle (gnn_criterion *crit,
00364                                          double         nu,
00365                                          double         wp,
00366                                          gnn_pbundle   *pb)
00367 {
00368     int status;
00369     size_t l;
00370     size_t size;
00371     gnn_criterion *c;
00372     gnn_weight_elimination *we;
00373 
00374     assert (crit != NULL);
00375     assert (pb != NULL);
00376 
00377     /* check for positive penalty term coefficient */
00378     if (nu < 0.0)
00379     {
00380         GSL_ERROR_VAL ("penalty term coefficient should be positive",
00381                        GSL_EINVAL, NULL);
00382     }
00383 
00384     /* check for positive weight scale factor */
00385     if (wp <= 0.0)
00386     {
00387         GSL_ERROR_VAL ("weight scale factor should be strictly positive",
00388                        GSL_EINVAL, NULL);
00389     }
00390 
00391     /* alloc memory */
00392     we = (gnn_weight_elimination *) malloc (sizeof (gnn_weight_elimination));
00393     if (we == NULL)
00394     {
00395         GSL_ERROR_VAL ("couldn't alloc memory for gnn_weight_elimination",
00396                        GSL_ENOMEM, NULL);
00397     }
00398 
00399     /* get view as gnn_criterion */
00400     c = (gnn_criterion *) we;
00401 
00402     /* get size */
00403     size = gnn_criterion_get_size (crit);
00404 
00405     /* initialize */
00406     status = gnn_criterion_init  (c,
00407                                   "gnn_weight_elimination",
00408                                   size,
00409                                   gnn_weight_elimination_e,
00410                                   gnn_weight_elimination_dy,
00411                                   gnn_weight_elimination_destroy);
00412     if (status)
00413     {
00414         gnn_criterion_destroy (c);
00415         GSL_ERROR_VAL ("couldn't initialize gnn_weight_elimination",
00416                        GSL_EFAILED, NULL);
00417     }
00418 
00419     /* get number of penalized parameters */
00420     l = gnn_pbundle_get_size (pb);
00421 
00422     /* set fields */
00423     we->subcrit = crit;
00424     we->pb      = pb;
00425     we->w       = gsl_vector_alloc (l);
00426     we->dw      = gsl_vector_alloc (l);
00427     we->nu      = nu;
00428     we->wp      = wp;
00429 
00430     if (we->dw == NULL || we->w == NULL)
00431     {
00432         gnn_criterion_destroy (c);
00433         GSL_ERROR_VAL ("couldn't allocate memory for internal "
00434                        "gnn_weight_elimination buffer", GSL_EFAILED, NULL);
00435     }
00436 
00437     return c;
00438 }
00439 
00440 /**
00441  * @brief Sets the regularization coefficient.
00442  * @ingroup gnn_weight_elimination_doc
00443  *
00444  * This function sets a new regularization coefficient for the
00445  * \ref gnn_weight_elimination regularizer. It should be positive.
00446  *
00447  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00448  * @param  nu   The new regularization coefficient \f$nu\f$.
00449  * @return Returns 0 if succeeded.
00450  */
00451 int
00452 gnn_weight_elimination_set_nu (gnn_criterion *crit, double nu)
00453 {
00454     gnn_weight_elimination *we;
00455 
00456     assert (crit != NULL);
00457 
00458     if (nu < 0.0)
00459     {
00460         GSL_ERROR ("penalty term coefficient should be positive", GSL_EINVAL);
00461     }
00462 
00463     we = (gnn_weight_elimination *) crit;
00464     we->nu = nu;
00465 
00466     return 0;
00467 }
00468 
00469 /**
00470  * @brief Gets the regularization coefficient.
00471  * @ingroup gnn_weight_elimination_doc
00472  *
00473  * This function gets the currently used regularization coefficient of the
00474  * \ref gnn_weight_elimination regularizer.
00475  *
00476  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00477  * @return Returns the regularization coefficient \f$nu\f$.
00478  */
00479 double
00480 gnn_weight_elimination_get_nu (gnn_criterion *crit)
00481 {
00482     gnn_weight_elimination *we;
00483 
00484     assert (crit != NULL);
00485 
00486     we = (gnn_weight_elimination *) crit;
00487     return we->nu;
00488 }
00489 
00490 /**
00491  * @brief Sets the scale parameter \f$\hat{w}\f$.
00492  * @ingroup gnn_weight_elimination_doc
00493  *
00494  * This function sets a new scale paramter \f$\hat{w}\f$ for the
00495  * \ref gnn_weight_elimination regularizer. It should be stricly positive.
00496  *
00497  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00498  * @param  wp   The new scale parameter \f$\hat{w}\f$.
00499  * @return Returns 0 if succeeded.
00500  */
00501 int
00502 gnn_weight_elimination_set_wp (gnn_criterion *crit, double wp)
00503 {
00504     gnn_weight_elimination *we;
00505 
00506     assert (crit != NULL);
00507 
00508     if (wp <= 0.0)
00509     {
00510         GSL_ERROR ("weight scale coefficient should be strictly positive",
00511                    GSL_EINVAL);
00512     }
00513 
00514     we = (gnn_weight_elimination *) crit;
00515     we->wp = wp;
00516 
00517     return 0;
00518 }
00519 
00520 /**
00521  * @brief Gets the current scale parameter \f$\hat{w}\f$.
00522  * @ingroup gnn_weight_elimination_doc
00523  *
00524  * This function gets the currently used scale parameter \f$\hat{w}\f$ by the
00525  * \ref gnn_weight_elimination regularizer.
00526  *
00527  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00528  * @return Returns the scale parameter \f$\hat{w}\f$.
00529  */
00530 double
00531 gnn_weight_elimination_get_wp (gnn_criterion *crit)
00532 {
00533     gnn_weight_elimination *we;
00534 
00535     assert (crit != NULL);
00536 
00537     we = (gnn_weight_elimination *) crit;
00538     return we->wp;
00539 }
00540 
00541 /**
00542  * @brief Perform parameter pruning.
00543  * @ingroup gnn_weight_elimination_doc
00544  *
00545  * This function pruns the penalized parameters that fall below the given
00546  * threshold \a th in magnitude. These are frozen to zero.
00547  *
00548  * @param  crit A pointer to a \ref gnn_weight_elimination criterion.
00549  * @param  th   The threshold for parameter pruning.
00550  * @return Returns 0 if succeeded.
00551  */
00552 int
00553 gnn_weight_elimination_prun (gnn_criterion *crit, double th)
00554 {
00555     size_t i;
00556     size_t size;
00557     gnn_weight_elimination *we;
00558 
00559     assert (crit != NULL);
00560 
00561     /* get view as a weight elimination criterion */
00562     we = (gnn_weight_elimination *) crit;
00563 
00564     /* get number of parameters */
00565     size = gnn_pbundle_get_size (we->pb);
00566 
00567     /* check every parameter for its value and */
00568     /* prun if it falls below the threshold.   */
00569     for (i=0; i<size; ++i)
00570     {
00571         double wi;
00572         
00573         wi = gnn_pbundle_get_w_at (we->pb, i);
00574         if (fabs (wi) < th)
00575         {
00576             gnn_pbundle_set_w_at (we->pb, i, 0.0);
00577             gnn_pbundle_set_f_at (we->pb, i, 1);
00578         }
00579     }
00580     
00581     return 0;
00582 }
00583 
00584 
00585 

Generated on Sun Jun 13 20:50:13 2004 for libgnn Gradient Retropropagation Machine Library by doxygen1.2.18