Main Page   Modules   Data Structures   File List   Data Fields   Globals   Related Pages  

gnn_trainer.h

Go to the documentation of this file.
00001 /***************************************************************************
00002  *  @file gnn_trainer.h
00003  *  @brief Trainer Header File.
00004  *
00005  *  @date   : 29-08-03 21:46
00006  *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
00007  *  Copyright  2003  Pedro Ortega C.
00008  ****************************************************************************/
00009 /*
00010  *  This program is free software; you can redistribute it and/or modify
00011  *  it under the terms of the GNU General Public License as published by
00012  *  the Free Software Foundation; either version 2 of the License, or
00013  *  (at your option) any later version.
00014  *
00015  *  This program is distributed in the hope that it will be useful,
00016  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018  *  GNU General Public License for more details.
00019  *
00020  *  You should have received a copy of the GNU General Public License
00021  *  along with this program; if not, write to the Free Software
00022  *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00023  */
00024 
00025 #ifndef _GNN_TRAINER_H_
00026 #define _GNN_TRAINER_H_
00027 
00028 
00029 
00030 /******************************************/
00031 /* Include Files                          */
00032 /******************************************/
00033 
00034 #include "gnn_evaluation.h"
00035 
00036 
00037 
00038 /******************************************/
00039 /* Typedefs                               */
00040 /******************************************/
00041 
/**
 * @brief The base datatype for a trainer.
 * @ingroup gnn_trainer_doc
 *
 * This is the datatype that contains the basic elements for the implementation
 * of a trainer. Every new type of trainer should extend this structure.
 */
typedef struct _gnn_trainer gnn_trainer;
00050 
/**
 * @brief The datatype for trainer reset functions.
 * @ingroup gnn_trainer_doc
 *
 * This is the datatype for the \ref gnn_trainer's "reset" function. These
 * functions are called by the \ref gnn_trainer_reset function. They should
 * reset the trainer's internal state, so that a new training procedure
 * can begin.
 *
 * @param trainer The trainer whose internal state is to be reset.
 * @return An error code (presumably 0 on success; the convention is not
 *         stated in this header -- confirm against the implementation).
 */
typedef int (*gnn_trainer_reset_type) (gnn_trainer *trainer);
00061 
/**
 * @brief The datatype for trainer train functions.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the sum of the patterns' costs.
 *
 * This is the datatype for the \ref gnn_trainer's "train" function. This
 * function is called by the \ref gnn_trainer_train function. Basically,
 * it should update the model's parameters in order to minimize the
 * criterion, using a minibatch of patterns, starting at "currpat". In order
 * to simplify the implementation, \ref libgnn provides the
 * \ref gnn_trainer_batch_process function, which processes a minibatch:
 * it returns mean cost of the so far evaluated patterns and the estimation
 * \f[ \frac{\partial \hat{E}}{\partial w} =
 *     \frac{ \sum_k p_k
 *            { \left . \frac{\partial E}{\partial w} \right | }_{x_k}
 *          }
 *          { \sum_k p_k } \f]
 * of the gradient, where the sums are taken for the patterns in the minibatch.
 * You can then use the returned gradient to update the model's parameters.
 * After that, don't forget to return the sum of the costs (which is returned
 * by \ref gnn_trainer_batch_process).
 *
 * NOTE(review): the text above says the function returns the sum of the
 * costs, yet this pointer type returns int (and \ref gnn_trainer_train
 * returns double) -- verify the intended return convention against the
 * implementation; the int is possibly an error code instead.
 *
 * @param trainer The trainer performing the minibatch update.
 */
typedef int (*gnn_trainer_train_type) (gnn_trainer *trainer);
00086 
/**
 * @brief The datatype for trainer destroy functions.
 * @ingroup gnn_trainer_doc
 *
 * This function is called by the \ref gnn_trainer_destroy function, and
 * should free all extra data that was allocated for the specific trainer
 * type. The \ref gnn_trainer data is handled automatically (and freed) by
 * \ref gnn_trainer_destroy, so only additional data should be freed
 * by implementations of this function type.
 *
 * @param trainer The trainer whose type-specific extra data is to be freed.
 */
typedef void (*gnn_trainer_destroy_type) (gnn_trainer *trainer);
00098 
/**
 * The base trainer structure (see the \ref gnn_trainer typedef).
 * Specific trainer types extend this structure; the type-specific behavior
 * is supplied through the "reset"/"train"/"destroy" function pointers below.
 */
struct _gnn_trainer
{
    const char    *type; /**< The type of the trainer.                        */

    gnn_node      *node; /**< A pointer to the node to be trained.            */
    gnn_criterion *crit; /**< A pointer to the criterion.                     */
    gnn_dataset   *data; /**< A pointer to the dataset.                       */

    size_t n;            /**< The size of the batches to be processed.        */
    size_t s;            /**< The number of the next pattern to be evaluated. */
    size_t epoch;        /**< The number of the current epoch.                */
    double sume;         /**< The sum of the epoch's evaluated criterions.    */
    double sump;         /**< The sum of the patterns' weights viewed so far. */

    gnn_grad *grad;      /**< The internal batch gradients evaluation buffers.*/

    gnn_trainer_reset_type   reset;      /**< The "reset" function pointer.   */
    gnn_trainer_train_type   train;      /**< The "train" function pointer.   */
    gnn_trainer_destroy_type destroy;    /**< The "destroy" function pointer. */
};
00119 
00120 
00121 
00122 /******************************************/
00123 /* Public Interface                       */
00124 /******************************************/
00125 
/**
 * @brief Returns the type name of the trainer.
 * (Presumably the structure's "type" field -- confirm in the implementation.)
 */
const char *
gnn_trainer_get_type (gnn_trainer *trainer);

/**
 * @brief Initializes a trainer base structure.
 *
 * Binds the trainer to the node to be trained, the criterion, the dataset,
 * and the type-specific "reset"/"train"/"destroy" hooks (see the
 * \ref gnn_trainer_reset_type, \ref gnn_trainer_train_type and
 * \ref gnn_trainer_destroy_type typedefs above).
 *
 * @return An error code (convention not visible in this header -- see the
 *         implementation).
 */
int
gnn_trainer_init (gnn_trainer             *trainer,
                  const char              *type,
                  gnn_node                *node,
                  gnn_criterion           *crit,
                  gnn_dataset             *data,
                  gnn_trainer_reset_type   reset,
                  gnn_trainer_train_type   train,
                  gnn_trainer_destroy_type destroy);

/**
 * @brief Destroys the trainer.
 * Calls the trainer's type-specific "destroy" hook to release extra data,
 * then frees the \ref gnn_trainer base data itself
 * (per the \ref gnn_trainer_destroy_type documentation above).
 */
void
gnn_trainer_destroy (gnn_trainer *trainer);

/**
 * @brief Resets the trainer's internal state so a new training procedure
 *        can begin; dispatches to the type-specific "reset" hook.
 */
int
gnn_trainer_reset (gnn_trainer *trainer);

/**
 * @brief Runs one training step; dispatches to the type-specific "train"
 *        hook, which updates the model's parameters on a minibatch.
 * @return Presumably the cost of the processed patterns (see
 *         \ref gnn_trainer_train_type) -- confirm in the implementation.
 */
double
gnn_trainer_train (gnn_trainer *trainer);


/** @brief Returns the dataset the trainer was initialized with. */
gnn_dataset *
gnn_trainer_get_dataset (gnn_trainer *trainer);

/** @brief Returns the node (model) the trainer was initialized with. */
gnn_node *
gnn_trainer_get_node (gnn_trainer *trainer);

/** @brief Returns the criterion the trainer was initialized with. */
gnn_criterion *
gnn_trainer_get_criterion (gnn_trainer *trainer);


/**
 * @brief Sets the size of the minibatches to be processed
 *        (the "n" field of \ref gnn_trainer).
 */
int
gnn_trainer_batch_set_size (gnn_trainer *trainer, size_t size);

/** @brief Returns the current minibatch size. */
size_t
gnn_trainer_batch_get_size (gnn_trainer *trainer);


/**
 * @brief Returns the number of the next pattern to be evaluated
 *        (the "s" field of \ref gnn_trainer).
 */
size_t
gnn_trainer_get_pattern_index (gnn_trainer *trainer);

/** @brief Returns the number of the current epoch. */
size_t
gnn_trainer_get_epoch (gnn_trainer *trainer);

/**
 * @brief Returns the cost of the current epoch.
 * (Presumably derived from the "sume"/"sump" accumulators -- confirm in the
 * implementation.)
 */
double
gnn_trainer_get_epoch_cost (gnn_trainer *trainer);

/**
 * @brief Returns the trainer's internal batch gradient evaluation buffers
 *        (the "grad" field of \ref gnn_trainer).
 */
gnn_grad *
gnn_trainer_batch_get_grad (gnn_trainer *trainer);

/**
 * @brief Returns the criterion value accumulated for the current batch.
 * (Exact semantics not visible in this header -- see the implementation.)
 */
double
gnn_trainer_batch_get_e (gnn_trainer *trainer);

/**
 * @brief Returns the batch gradient with respect to the inputs.
 * (Presumably taken from the "grad" buffers -- confirm in the
 * implementation.)
 */
gsl_vector *
gnn_trainer_batch_get_dx (gnn_trainer *trainer);

/**
 * @brief Returns the batch gradient with respect to the weights.
 * (Presumably taken from the "grad" buffers -- confirm in the
 * implementation.)
 */
gsl_vector *
gnn_trainer_batch_get_dw (gnn_trainer *trainer);

/**
 * @brief Processes a minibatch: evaluates the patterns' costs and the
 *        weighted gradient estimate (see \ref gnn_trainer_train_type).
 *
 * NOTE(review): the \ref gnn_trainer_train_type documentation says this
 * function returns the sum of the costs, yet the return type here is int --
 * verify whether the int is an error code and the cost is retrieved via
 * \ref gnn_trainer_batch_get_e.
 */
int
gnn_trainer_batch_process (gnn_trainer *trainer);

/**
 * @brief Advances the trainer to the next minibatch (presumably updating the
 *        pattern index "s" and the epoch counters -- confirm in the
 *        implementation).
 */
int
gnn_trainer_batch_next (gnn_trainer *trainer);
00192 
00193 
00194 
00195 #endif /* _GNN_TRAINER_H_ */
00196 
00197 
00198 

Generated on Sun Jun 13 20:50:12 2004 for libgnn Gradient Retropropagation Machine Library by doxygen1.2.18