00001 /*************************************************************************** 00002 * @file gnn_trainer.h 00003 * @brief Trainer Header File. 00004 * 00005 * @date : 29-08-03 21:46 00006 * @author : Pedro Ortega C. <peortega@dcc.uchile.cl> 00007 * Copyright 2003 Pedro Ortega C. 00008 ****************************************************************************/ 00009 /* 00010 * This program is free software; you can redistribute it and/or modify 00011 * it under the terms of the GNU General Public License as published by 00012 * the Free Software Foundation; either version 2 of the License, or 00013 * (at your option) any later version. 00014 * 00015 * This program is distributed in the hope that it will be useful, 00016 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00017 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00018 * GNU Library General Public License for more details. 00019 * 00020 * You should have received a copy of the GNU General Public License 00021 * along with this program; if not, write to the Free Software 00022 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. 00023 */ 00024 00025 #ifndef _GNN_TRAINER_H_ 00026 #define _GNN_TRAINER_H_ 00027 00028 00029 00030 /******************************************/ 00031 /* Include Files */ 00032 /******************************************/ 00033 00034 #include "gnn_evaluation.h" 00035 00036 00037 00038 /******************************************/ 00039 /* Typedefs */ 00040 /******************************************/ 00041 00042 /** 00043 * @brief The datatype for trainer reset functions. 00044 * @ingroup gnn_trainer_doc 00045 * 00046 * This is the datatype that contains the basic elements for the implementation 00047 * of a trainer. Every new type of trainer should extend this structure. 00048 */ 00049 typedef struct _gnn_trainer gnn_trainer; 00050 00051 /** 00052 * @brief The datatype for trainer reset functions. 
00053 * @ingroup gnn_trainer_doc 00054 * 00055 * This is the datatype for the \ref gnn_trainer's "reset" function. These 00056 * functions are called by the \ref gnn_trainer_reset function. They should 00057 * reset the trainer's internal state, so that a new training procedure 00058 * can begin. 00059 */ 00060 typedef int (*gnn_trainer_reset_type) (gnn_trainer *trainer); 00061 00062 /** 00063 * @brief The datatype for trainer train functions. 00064 * @ingroup gnn_trainer_doc 00065 * 00066 * This function returns the sum of the patterns' costs. 00067 * 00068 * This is the datatype for the \ref gnn_trainer's "train" function. This 00069 * function is called by the \ref gnn_trainer_train function. Basically, 00070 * it should update the model's parameters in order to minimize the 00071 * criterion, using a minibatch of patterns, starting at "currpat". In order 00072 * to simplify the implementation, \ref libgnn provides the 00073 * \ref gnn_trainer_batch_process function, which processes a minibatch: 00074 * it returns mean cost of the so far evaluated patterns and the estimation 00075 * \f[ \frac{\partial \hat{E}}{\partial w} = 00076 * \frac{ \sum_k p_k 00077 * { \left . \frac{\partial E}{\partial w} \right | }_{x_k} 00078 * } 00079 * { \sum_k p_k } \f] 00080 * of the gradient, where the sums are taken for the patterns in the minibatch. 00081 * You can then use the returned gradient to update the model's parameters. 00082 * After that, don't forget to return the sum of the costs (which is returned 00083 * by \ref gnn_trainer_batch_process). 00084 */ 00085 typedef int (*gnn_trainer_train_type) (gnn_trainer *trainer); 00086 00087 /** 00088 * @brief The datatype for trainer destroy functions. 00089 * @ingroup gnn_trainer_doc 00090 * 00091 * This function is called by the \ref gnn_trainer_destroy function, and 00092 * should free all extra data that was allocated for the specific trainer 00093 * type. 
The \ref gnn_trainer data is handled automatically (and freed) by 00094 * \ref gnn_trainer_destroy, so only additional data should be freed 00095 * by implementations of this function type. 00096 */ 00097 typedef void (*gnn_trainer_destroy_type) (gnn_trainer *trainer); 00098 00099 struct _gnn_trainer 00100 { 00101 const char *type; /**< The type of the trainer. */ 00102 00103 gnn_node *node; /**< A pointer to the node to be trained. */ 00104 gnn_criterion *crit; /**< A pointer to the criterion. */ 00105 gnn_dataset *data; /**< A pointer to the dataset. */ 00106 00107 size_t n; /**< The size of the batches to be processed. */ 00108 size_t s; /**< The number of the next pattern to be evaluated. */ 00109 size_t epoch; /**< The number of the current epoch. */ 00110 double sume; /**< The sum of the epochs evaluated criterions. */ 00111 double sump; /**< The sum of the patterns' weights viewed so far. */ 00112 00113 gnn_grad *grad; /**< The internal batch gradients evaluation buffers.*/ 00114 00115 gnn_trainer_reset_type reset; /**< The "reset" function pointer. */ 00116 gnn_trainer_train_type train; /**< The "train" function pointer. */ 00117 gnn_trainer_destroy_type destroy; /**< The "destroy" function pointer. 
*/ 00118 }; 00119 00120 00121 00122 /******************************************/ 00123 /* Public Interface */ 00124 /******************************************/ 00125 00126 const char * 00127 gnn_trainer_get_type (gnn_trainer *trainer); 00128 00129 int 00130 gnn_trainer_init (gnn_trainer *trainer, 00131 const char *type, 00132 gnn_node *node, 00133 gnn_criterion *crit, 00134 gnn_dataset *data, 00135 gnn_trainer_reset_type reset, 00136 gnn_trainer_train_type train, 00137 gnn_trainer_destroy_type destroy); 00138 00139 void 00140 gnn_trainer_destroy (gnn_trainer *trainer); 00141 00142 int 00143 gnn_trainer_reset (gnn_trainer *trainer); 00144 00145 double 00146 gnn_trainer_train (gnn_trainer *trainer); 00147 00148 00149 gnn_dataset * 00150 gnn_trainer_get_dataset (gnn_trainer *trainer); 00151 00152 gnn_node * 00153 gnn_trainer_get_node (gnn_trainer *trainer); 00154 00155 gnn_criterion * 00156 gnn_trainer_get_criterion (gnn_trainer *trainer); 00157 00158 00159 int 00160 gnn_trainer_batch_set_size (gnn_trainer *trainer, size_t size); 00161 00162 size_t 00163 gnn_trainer_batch_get_size (gnn_trainer *trainer); 00164 00165 00166 size_t 00167 gnn_trainer_get_pattern_index (gnn_trainer *trainer); 00168 00169 size_t 00170 gnn_trainer_get_epoch (gnn_trainer *trainer); 00171 00172 double 00173 gnn_trainer_get_epoch_cost (gnn_trainer *trainer); 00174 00175 gnn_grad * 00176 gnn_trainer_batch_get_grad (gnn_trainer *trainer); 00177 00178 double 00179 gnn_trainer_batch_get_e (gnn_trainer *trainer); 00180 00181 gsl_vector * 00182 gnn_trainer_batch_get_dx (gnn_trainer *trainer); 00183 00184 gsl_vector * 00185 gnn_trainer_batch_get_dw (gnn_trainer *trainer); 00186 00187 int 00188 gnn_trainer_batch_process (gnn_trainer *trainer); 00189 00190 int 00191 gnn_trainer_batch_next (gnn_trainer *trainer); 00192 00193 00194 00195 #endif /* _GNN_TRAINER_H_ */ 00196 00197 00198
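/* 1.2.18 -- footer left over from the extracted source listing (presumably the documentation generator's version; not part of the header). */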