
gnn_trainer.c

/***************************************************************************
 *  @file gnn_trainer.c
 *  @brief Trainer Implementation.
 *
 *  @date   : 29-08-03 21:45
 *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
 *  Copyright (C) 2003  Pedro Ortega C.
 ****************************************************************************/
/*
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */


/**
 * @defgroup gnn_trainer_doc gnn_trainer : Trainers for Models.
 * @ingroup libgnn_trainer
 * @todo Provide an API for those who want to implement the training procedure
 *       from scratch.
 *
 * The \ref gnn_trainer type defines a common interface for handling training
 * algorithms. It provides the basic structure on which all training
 * algorithms are based. Like all basic structures in \ref libgnn, the
 * \ref gnn_trainer structure cannot be used by itself. Instead,
 * implementations of trainers fill in the necessary parts to become a
 * fully-functional \ref gnn_trainer.
 *
 * What is a \ref gnn_trainer? A \ref gnn_trainer is an object that, given
 * a Model (in form of a \ref gnn_node), a criterion (\ref gnn_criterion)
 * and a dataset (\ref gnn_dataset):
 * - presents the input patterns sequentially to the Model, that is,
 *   \f$x^1, x^2, x^3, \ldots, x^N\f$, where \f$x^k\f$ denotes the \f$k\f$-th
 *   input pattern, and \f$N\f$ is the number of patterns in the dataset,
 * - obtains its outputs \f$y^k = f(x^k)\f$ for \f$k=1,\ldots,N\f$,
 * - compares them with the target patterns using the criterion,
 * - computes the gradient with respect to the parameters, and
 * - adjusts them.
 * These steps are performed in order to optimize the Model.
 *
 * The patterns are presented in order, one after the other, until the last
 * pattern has been presented. After that, the dataset is reset and the
 * procedure is repeated. Such a cycle, in which each pattern is evaluated
 * once, is called an epoch.
 *
 * Although one pattern is presented at a time, the needed gradient
 * \f$\frac{\partial E}{\partial w}\f$ is estimated by performing a weighted
 * sum over the particular gradients \f$\frac{\partial E}{\partial w}|_{x^k}\f$
 * for each pattern \f$x^k\f$:
 * \f[ \frac{\partial \hat{E}}{\partial w} =
 *     \frac{ \sum_k p_k
 *            { \left . \frac{\partial E}{\partial w} \right | }_{x^k}
 *          }
 *          { \sum_k p_k }
 * \f]
 * where \f$p_k\f$ is the weight of the pattern \f$x^k\f$ (for uniformly
 * weighted patterns, \f$p_k = 1\f$ for all \f$k\f$, this reduces to the
 * plain average). How many patterns are considered in this sum? In the
 * literature, when only one pattern is considered, the procedure is called
 * on-line training; in batch training, all patterns are taken. In
 * \ref libgnn, this sum is taken over so-called minibatches, which are
 * simply groups of patterns of the same size. E.g. if there are 10 patterns,
 * and the minibatches are of size 4, then the patterns are presented in the
 * following groups:
 * \f[ \{x^1, x^2, x^3, x^4\}, \{x^5, x^6, x^7, x^8\}, \{x^9, x^{10}\} \f]
 * (Please note that the last minibatch has been cut down to fit.)
 * This operation is executed automatically by calling \ref gnn_trainer_train
 * on a \ref gnn_trainer. \ref gnn_trainer_train returns the mean cost
 * of all patterns presented to the model during the current epoch and
 * adjusts the parameters of the model, as the sketch below illustrates.
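 *
 * A minimal usage sketch (it assumes a hypothetical concrete trainer
 * constructor "my_trainer_new", such as the one developed in the
 * documentation of \ref gnn_trainer_init; all other calls belong to the
 * \ref gnn_trainer interface):
 * \code
 *     // train in minibatches of size 4 for 100 epochs
 *     gnn_trainer *trainer = my_trainer_new (model, crit, data);
 *     gnn_trainer_batch_set_size (trainer, 4);
 *     while (gnn_trainer_get_epoch (trainer) <= 100)
 *         gnn_trainer_train (trainer);
 *     gnn_trainer_destroy (trainer);
 * \endcode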
 *
 * To reset a trainer and start over again (the parameters are kept),
 * call \ref gnn_trainer_reset.
 *
 * How do you extend the \ref gnn_trainer structure to implement your own
 * trainer? In the same spirit as all other basic structures, you can
 * use the \ref gnn_trainer structure itself or extend it like the following
 * hypothetical "my_trainer" structure:
 * \code
 *     typedef struct _my_trainer
 *     {
 *         gnn_trainer trainer;
 *
 *         // the specific data for "my_trainer"
 *         ...
 *     } my_trainer;
 * \endcode
 * This allows you to cast a "my_trainer" structure to a \ref gnn_trainer
 * structure whenever needed.
 *
 * Secondly, you have to provide at least four functions: a constructor,
 * which allocates and initializes a new "my_trainer" structure; a "reset"
 * function to reset the training procedure, which should be of type
 * \ref gnn_trainer_reset_type; a "train" function to train a minibatch
 * and update the parameters, which should be of type
 * \ref gnn_trainer_train_type; and a "destroy" function, which destroys
 * all additional data used by the "my_trainer" structure, of type
 * \ref gnn_trainer_destroy_type (alternatively, you can use the default
 * implementations for the "reset" and "destroy" functions).
 *
 * Please read the specific documentation for the function types that you
 * should provide: \ref gnn_trainer_reset_type and \ref gnn_trainer_train_type.
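 *
 * As an illustration, a "train" function implementing plain gradient
 * descent could look roughly like the following sketch. The learning rate
 * "eta" is an assumed field of the hypothetical "my_trainer" structure,
 * and the actual parameter update is left open, since the way the
 * parameter vector is obtained depends on the node:
 * \code
 *     static double
 *     my_trainer_train (gnn_trainer *trainer)
 *     {
 *         my_trainer *t = (my_trainer *) trainer;
 *         gsl_vector *dw;
 *
 *         // evaluate cost and gradients on the current minibatch
 *         gnn_trainer_batch_process (trainer);
 *         dw = gnn_trainer_batch_get_dw (trainer);
 *
 *         // update the parameters: w <- w - t->eta * dw
 *         // (obtaining the parameter vector w from the node is omitted)
 *         ...
 *
 *         // move on to the next minibatch (and epoch, if necessary)
 *         gnn_trainer_batch_next (trainer);
 *
 *         return gnn_trainer_get_epoch_cost (trainer);
 *     }
 * \endcode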
 */



/******************************************/
/* Include Files                          */
/******************************************/

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <gsl/gsl_errno.h>
#include "gnn_trainer.h"
#include "gnn_utilities.h"



/******************************************/
/* Static Declaration                     */
/******************************************/

static int
gnn_trainer_default_reset (gnn_trainer *trainer);

static void
gnn_trainer_default_destroy (gnn_trainer *trainer);



/******************************************/
/* Static Implementation                  */
/******************************************/

/**
 * @brief Default "reset" function for a trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function is the default "reset" function for a trainer. It does nothing.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return 0 if succeeded.
 */
static int
gnn_trainer_default_reset (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return 0;
}

/**
 * @brief Default "destroy" function for a trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function is the default "destroy" function for a trainer. It assumes
 * that there isn't any additional data for the specific trainer type, so
 * it simply returns.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 */
static void
gnn_trainer_default_destroy (gnn_trainer *trainer)
{
    return;
}



/******************************************/
/* Public Interface                       */
/******************************************/

/**
 * @brief Get the trainer's type.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a string which contains the name of the trainer's
 * type. The string should not be modified.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to the trainer's own type string.
 */
const char *
gnn_trainer_get_type (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->type;
}

/**
 * @brief Initializes a \ref gnn_trainer structure.
 * @ingroup gnn_trainer_doc
 *
 * This function initializes a \ref gnn_trainer structure by setting its
 * internal fields adequately. Every extension to the \ref gnn_trainer
 * structure should have a constructor that calls this function to:
 *
 * - define the trainer's type,
 * - set the model to be trained (the node),
 * - set the criterion to be used,
 * - set the dataset with the patterns to be used during training,
 * - install the type-specific functions for "reset", "train" and "destroy".
 *
 * All arguments are mandatory, except for the "reset" and the "destroy"
 * functions: if they are specified as NULL, then the defaults are installed
 * instead (see \ref gnn_trainer_default_reset and
 * \ref gnn_trainer_default_destroy).
 *
 * The associated sizes should match. That is, the size of the input patterns
 * contained in the dataset should correspond to the node's input size.
 * Likewise, the size of the output patterns should match the node's output
 * size and the criterion's size.
 *
 * The following example should clarify its use. Suppose you want to create
 * a new trainer type, called "my_trainer". You have already implemented
 * two of the needed functions, "my_trainer_reset" and "my_trainer_train".
 * Your training function does not need any additional data, and so the
 * destructor isn't needed. Then, the following code could be a possible
 * implementation of the constructor:
 * \code
   gnn_trainer *
   my_trainer_new (gnn_node *model, gnn_criterion *crit, gnn_dataset *data)
   {
       int status;
       my_trainer  *t;     // this is just an extension of gnn_trainer
       gnn_trainer *tview; // points to the same structure, but uses
                           // another view

       // alloc space
       t = (my_trainer *) malloc (sizeof (*t));
       if (t == NULL)
           print_error_message_and_return_error_code ();

       // cast to a gnn_trainer
       tview = (gnn_trainer *) t;

       // initialize
       status = gnn_trainer_init (tview,
                                  "my_trainer",
                                  model,
                                  crit,
                                  data,
                                  my_trainer_reset,
                                  my_trainer_train,
                                  NULL);
       if (status)
           print_error_message_and_return_error_code ();

       // finish initialization
       ...

       return tview;
   }
 * \endcode
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @param  type    A string with the typename. It will be duplicated.
 * @param  node    A pointer to a \ref gnn_node.
 * @param  crit    A pointer to a \ref gnn_criterion.
 * @param  data    A pointer to a \ref gnn_dataset.
 * @param  reset   A pointer to the "reset" function (or NULL).
 * @param  train   A pointer to the "train" function.
 * @param  destroy A pointer to the "destroy" function (or NULL).
 * @return Returns 0 if initialization succeeded.
 */
int
gnn_trainer_init (gnn_trainer             *trainer,
                  const char              *type,
                  gnn_node                *node,
                  gnn_criterion           *crit,
                  gnn_dataset             *data,
                  gnn_trainer_reset_type   reset,
                  gnn_trainer_train_type   train,
                  gnn_trainer_destroy_type destroy)
{
    assert (trainer != NULL);

    /* check arguments */
    if (type == NULL)
        GSL_ERROR ("invalid trainer type", GSL_EINVAL);
    if (node == NULL)
        GSL_ERROR ("invalid node", GSL_EINVAL);
    if (crit == NULL)
        GSL_ERROR ("invalid criterion", GSL_EINVAL);
    if (data == NULL)
        GSL_ERROR ("invalid dataset", GSL_EINVAL);
    if (train == NULL)
        GSL_ERROR ("invalid train function", GSL_EINVAL);

    /* check sizes */
    if (gnn_dataset_input_get_size (data) != gnn_node_input_get_size (node))
        GSL_ERROR ("the input pattern's size does not match with the node's "
                   "input size", GSL_EINVAL);
    if (gnn_dataset_output_get_size (data) != gnn_node_output_get_size (node))
        GSL_ERROR ("the output pattern's size does not match with the node's "
                   "output size", GSL_EINVAL);
    if (gnn_dataset_output_get_size (data) != gnn_criterion_get_size (crit))
        GSL_ERROR ("the output pattern's size does not match with the "
                   "criterion's size", GSL_EINVAL);

    /* check if sizes are strictly positive */
    if (gnn_node_input_get_size (node) <= 0)
        GSL_ERROR ("the input size should be strictly positive",
                   GSL_EINVAL);
    if (gnn_node_output_get_size (node) <= 0)
        GSL_ERROR ("the output size should be strictly positive",
                   GSL_EINVAL);
    if (gnn_node_param_get_size (node) <= 0)
        GSL_ERROR ("the parameter size should be strictly positive",
                   GSL_EINVAL);

    /* set fields */
    trainer->type    = strdup (type);
    trainer->node    = node;
    trainer->crit    = crit;
    trainer->data    = data;
    trainer->n       = 1;
    trainer->s       = 0;
    trainer->epoch   = 1;
    trainer->sume    = 0.0;
    trainer->sump    = 0.0;
    trainer->grad    = gnn_grad_new (node, crit, data);
    trainer->reset   = gnn_trainer_default_reset;
    trainer->train   = NULL;
    trainer->destroy = gnn_trainer_default_destroy;

    /* install functions */
    if (reset != NULL)
        trainer->reset = reset;
    if (train != NULL)
        trainer->train = train;
    if (destroy != NULL)
        trainer->destroy = destroy;

    return 0;
}

/**
 * @brief Destroys the trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function destroys the \ref gnn_trainer by calling the installed
 * destructor function for the trainer type and by freeing the
 * \ref gnn_trainer's structure.
 *
 * It doesn't destroy the node, the criterion, or the dataset. They
 * should be destroyed independently, as sketched below.
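 *
 * A sketch of a correct teardown order (the destructor names for the
 * node, criterion and dataset are hypothetical):
 * \code
 *     gnn_trainer_destroy (trainer); // the trainer first
 *     gnn_node_destroy (node);       // then the objects it references
 *     gnn_criterion_destroy (crit);
 *     gnn_dataset_destroy (data);
 * \endcode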
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 */
void
gnn_trainer_destroy (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    /* destroy type-specific data */
    trainer->destroy (trainer);

    /* free structure */
    if (trainer->grad != NULL)
        gnn_grad_destroy (trainer->grad);
    if (trainer->type != NULL)
        free ((char *) trainer->type);
    free (trainer);
}

/**
 * @brief Resets the trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function resets the trainer, i.e. the error is cleared, the epoch
 * counter is reset, the pattern iterator is set to the first pattern, etc.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_reset (gnn_trainer *trainer)
{
    int status;

    assert (trainer != NULL);

    /* reset trainer */
    status = trainer->reset (trainer);
    if (status)
        return status;

    /* reset common fields */
    trainer->s     = 0;
    trainer->epoch = 1;
    trainer->sume  = 0.0;
    trainer->sump  = 0.0;

    return 0;
}

/**
 * @brief Trains the node.
 * @ingroup gnn_trainer_doc
 *
 * This function executes a full iteration of the training algorithm on the
 * node. In this context, an iteration consists of presenting a batch of
 * patterns to the node and adjusting its parameters accordingly.
 *
 * Every call to this function presents the next patterns until the end
 * is reached. The last batch will be smaller if the total number of
 * patterns is not a multiple of the batch size. If the end is reached,
 * then the trainer goes to the next epoch and the dataset is
 * automatically reset.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns the mean value of the cost (as computed by the criterion)
 *         for the current epoch.
 */
double
gnn_trainer_train (gnn_trainer *trainer)
{
    double cost;

    assert (trainer != NULL);

    /* train */
    cost = trainer->train (trainer);

    return cost;
}

/**
 * @brief Get the dataset.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the dataset that the trainer uses for
 * training.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_dataset.
 */
gnn_dataset *
gnn_trainer_get_dataset (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->data;
}

/**
 * @brief Get the model.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the model that the trainer trains.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_node.
 */
gnn_node *
gnn_trainer_get_node (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->node;
}

/**
 * @brief Get the criterion.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the criterion that the trainer uses for
 * training.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_criterion.
 */
gnn_criterion *
gnn_trainer_get_criterion (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->crit;
}

/**
 * @brief Set the size of the minibatches.
 * @ingroup gnn_trainer_doc
 *
 * This function sets the size of the minibatches that are processed
 * upon a call of \ref gnn_trainer_train. This size should be strictly
 * positive. Although it can be greater than the number of available
 * patterns in the dataset, the effective minibatch will then be smaller.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @param  size    The size of the minibatches.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_set_size (gnn_trainer *trainer, size_t size)
{
    assert (trainer != NULL);

    /* check batch size */
    if (size < 1)
    {
        GSL_ERROR ("the size of the batch should be strictly positive",
                   GSL_EINVAL);
    }

    /* set batchsize */
    trainer->n = size;

    return 0;
}

/**
 * @brief Get the size of the minibatches.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a strictly positive number corresponding to the
 * size of the minibatches that are processed upon a call of
 * \ref gnn_trainer_train.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns the size of the minibatches.
 */
size_t
gnn_trainer_batch_get_size (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->n;
}

/**
 * @brief Get the index of the next pattern.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the index of the first pattern of the minibatch
 * to be processed next.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns the index of the next pattern.
 */
size_t
gnn_trainer_get_pattern_index (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->s;
}

/**
 * @brief Get the number of the current epoch.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the number of the current epoch.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns the number of the current epoch.
 */
size_t
gnn_trainer_get_epoch (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->epoch;
}

/**
 * @brief Get the mean cost.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the mean cost of the patterns presented in the current
 * epoch. That is, if the patterns \f$ x^1, \ldots, x^{k'} \f$ have been
 * evaluated in the current epoch, then the value
 * \f[ <E>_{k'} = \frac{1}{\sum_{k=1}^{k'} p_k}
 *                \sum_{k=1}^{k'} p_k E(y^k, t^k) \f]
 * where \f$p_k\f$ is the weight of the \f$k\f$-th pattern, will be returned.
 *
 * It is important to note that this value also includes the cost of the
 * last evaluated minibatch.
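 *
 * For instance, the running mean cost can be monitored after each training
 * step (a sketch; assumes <stdio.h> and a "trainer" in scope):
 * \code
 *     gnn_trainer_train (trainer);
 *     printf ("epoch %u: <E> = %g\n",
 *             (unsigned) gnn_trainer_get_epoch (trainer),
 *             gnn_trainer_get_epoch_cost (trainer));
 * \endcode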
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns the weighted mean \f$<E>_{k'}\f$.
 */
double
gnn_trainer_get_epoch_cost (gnn_trainer *trainer)
{
    double Se;
    double Sp;

    assert (trainer != NULL);

    /* compute cost sum and weight sum */
    Se = trainer->sume + GNN_GRAD_SUME (trainer->grad);
    Sp = trainer->sump + GNN_GRAD_SUMP (trainer->grad);

    return Se / Sp;
}

/**
 * @brief Returns a pointer to the internal \ref gnn_grad buffer.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the internal error and gradient
 * evaluation buffer.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to the \ref gnn_grad.
 */
gnn_grad *
gnn_trainer_batch_get_grad (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return trainer->grad;
}

/**
 * @brief Returns the last evaluated batch's mean error.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean error \f$<E>\f$.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<E>\f$.
 */
double
gnn_trainer_batch_get_e (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return GNN_GRAD_E (trainer->grad);
}

/**
 * @brief Returns the last evaluated batch's gradient dx.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean gradient with
 * respect to its inputs \f$<\frac{\partial E}{\partial x}>\f$.
 *
 * The returned vector can be freely accessed and its values modified, but
 * it should not be freed.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<\frac{\partial E}{\partial x}>\f$.
 */
gsl_vector *
gnn_trainer_batch_get_dx (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return GNN_GRAD_DX (trainer->grad);
}

/**
 * @brief Returns the last evaluated batch's gradient dw.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean gradient with
 * respect to its parameters \f$<\frac{\partial E}{\partial w}>\f$.
 *
 * The returned vector can be freely accessed and its values modified, but
 * it should not be freed.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<\frac{\partial E}{\partial w}>\f$.
 */
gsl_vector *
gnn_trainer_batch_get_dw (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    return GNN_GRAD_DW (trainer->grad);
}

/**
 * @brief Processes a batch.
 * @ingroup gnn_trainer_doc
 *
 * This function processes the current minibatch of patterns, computing
 * the batch's mean cost and its gradients \f$\frac{\partial E}{\partial x}\f$
 * and \f$\frac{\partial E}{\partial w}\f$. They can be retrieved afterwards
 * by invoking
 *
 * - \ref gnn_trainer_batch_get_e
 * - \ref gnn_trainer_batch_get_dx
 * - \ref gnn_trainer_batch_get_dw
 * - \ref gnn_trainer_get_epoch_cost
 *
 * These values are obtained by averaging over the batch's patterns.
 *
 * Several calls to this function will process the same minibatch. Only
 * a call to \ref gnn_trainer_batch_next moves on to the next one.
 *
 * This function is very handy when implementing your own trainers. It
 * provides an easy way to obtain the gradient and handles the necessary
 * field updates on the \ref gnn_trainer. Please refer to
 * \ref gnn_trainer_train_type for details.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_process (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    /* process batch */
    gnn_grad_pats (trainer->grad, gnnGradDw, trainer->s, trainer->n);

    return 0;
}

/**
 * @brief Moves on to the next batch.
 * @ingroup gnn_trainer_doc
 *
 * This function moves on to the next batch. That is, it sets the trainer's
 * internal state to prepare for processing the next batch. If the end of the
 * dataset is reached, then the next epoch is initiated, the mean cost
 * information is cleared, and the dataset is reset.
 *
 * This function is very handy when implementing your own trainers. It
 * should be used in conjunction with \ref gnn_trainer_batch_process.
 *
 * @param  trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_next (gnn_trainer *trainer)
{
    assert (trainer != NULL);

    /* update current pattern index */
    trainer->s += trainer->n;

    /* update mean cost values */
    trainer->sume += GNN_GRAD_SUME (trainer->grad);
    trainer->sump += GNN_GRAD_SUMP (trainer->grad);

    /* if end reached, advance to next epoch */
    if (trainer->s >= gnn_dataset_get_size (trainer->data))
    {
        trainer->epoch++;
        trainer->s    = 0;
        trainer->sume = 0.0;
        trainer->sump = 0.0;
        gnn_dataset_reset (trainer->data);
    }

    return 0;
}
