/****************************************************************************
 * @file   gnn_trainer.c
 * @brief  Trainer Implementation.
 *
 * @date   : 29-08-03 21:45
 * @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
 * Copyright 2003 Pedro Ortega C.
 ****************************************************************************/
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Library General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
 */



/**
 * @defgroup gnn_trainer_doc gnn_trainer : Trainers for Models.
 * @ingroup libgnn_trainer
 * @todo Provide an API for those who want to implement the training procedure
 * from scratch.
 *
 * The \ref gnn_trainer type defines a common interface for handling training
 * algorithms. It provides the basic structure on which all training
 * algorithms are built. Like all basic structures in \ref libgnn, the
 * \ref gnn_trainer structure cannot be used by itself. Instead,
 * implementations of trainers fill in the necessary parts to become a
 * fully-functional \ref gnn_trainer.
 *
 * What is a \ref gnn_trainer? A \ref gnn_trainer is an object that, given
 * a Model (in the form of a \ref gnn_node), a criterion (\ref gnn_criterion)
 * and a dataset (\ref gnn_dataset):
 * - presents the input patterns sequentially to the Model, that is
 *   \f$x^1, x^2, x^3, \ldots, x^N\f$, where \f$x^k\f$ denotes the \f$k\f$-th
 *   input pattern, and \f$N\f$ is the order (size) of the dataset,
 * - obtains its outputs \f$y^k = f(x^k)\f$ for \f$k=1,\ldots,N\f$,
 * - compares them with the target patterns using the criterion,
 * - computes the gradient with respect to the parameters, and
 * - adjusts the parameters.
 * These steps are performed in order to optimize the Model.
 *
 * The patterns are presented in order, one after the other, until the last
 * pattern has been presented. Then the dataset is reset and the procedure
 * is repeated. This cycle, in which each pattern is evaluated once, is
 * called an epoch.
 *
 * Although one pattern is presented at a time, the needed gradient
 * \f$\frac{\partial E}{\partial w}\f$ is estimated by performing a weighted
 * sum over the individual gradients
 * \f$\left.\frac{\partial E}{\partial w}\right|_{x^k}\f$
 * for each pattern \f$x^k\f$:
 * \f[ \frac{\partial \hat{E}}{\partial w} =
 *     \frac{ \sum_k p_k
 *            { \left . \frac{\partial E}{\partial w} \right | }_{x^k}
 *          }
 *          { \sum_k p_k }
 * \f]
 * where \f$p_k\f$ is the weight of the pattern \f$x^k\f$. How many patterns
 * are considered in this sum?
 * In the literature, when only one pattern is considered, the procedure is
 * called on-line training; in batch training, all patterns are taken. In
 * \ref libgnn, this sum is taken over so-called minibatches, which are
 * simply groups of patterns of the same size. E.g. if there are 10 patterns
 * and the minibatches are of size 4, then the patterns are presented in the
 * following groups:
 * \f[ \{x^1, x^2, x^3, x^4\}, \{x^5, x^6, x^7, x^8\}, \{x^9, x^{10}\} \f]
 * (Please note that the last minibatch has been cut down to fit.)
 * This operation is executed automatically by calling \ref gnn_trainer_train
 * on a \ref gnn_trainer. \ref gnn_trainer_train returns the mean cost of
 * all patterns presented to the model during the current epoch and adjusts
 * the parameters of the model.
 *
 * To reset a trainer and start over again (the parameters are kept),
 * call \ref gnn_trainer_reset.
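 *
 * For illustration, a typical training loop might look like the following
 * sketch. The constructor "my_trainer_new" is hypothetical, since trainers
 * are always created through a type-specific constructor (an example
 * constructor is shown in the documentation of \ref gnn_trainer_init):
 * \code
 * gnn_trainer *t;
 * double       cost;
 *
 * t = my_trainer_new (model, crit, data);  // hypothetical constructor
 * gnn_trainer_batch_set_size (t, 4);       // minibatches of 4 patterns
 *
 * // each call to gnn_trainer_train processes one minibatch; the epoch
 * // counter advances automatically when the dataset has been exhausted
 * while (gnn_trainer_get_epoch (t) <= 100)
 *   cost = gnn_trainer_train (t);          // running mean cost of the epoch
 *
 * gnn_trainer_destroy (t);
 * \endcode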
 *
 * How do you extend the \ref gnn_trainer structure to implement your own
 * trainer? In the same spirit as all other basic structures, you can
 * use the \ref gnn_trainer structure itself or extend it like the following
 * hypothetical "my_trainer" structure:
 * \code
 * typedef struct _my_trainer
 * {
 *   gnn_trainer trainer;
 *
 *   // the specific data for "my_trainer"
 *   ...
 * } my_trainer;
 * \endcode
 * This allows you to cast a "my_trainer" structure to a \ref gnn_trainer
 * structure whenever needed.
 *
 * Secondly, you have to provide at least four functions: a constructor,
 * which allocates and initializes a new "my_trainer" structure; a "reset"
 * function to reset the training procedure, which should be of type
 * \ref gnn_trainer_reset_type; a "train" function to train a minibatch
 * and update the parameters, which should be of type
 * \ref gnn_trainer_train_type; and a "destroy" function, which destroys
 * all additional data used by the "my_trainer" structure, of type
 * \ref gnn_trainer_destroy_type (alternatively, you can use the default
 * implementations for the "reset" and "destroy" functions).
 *
 * Please read the specific documentation for the function types that you
 * should provide: \ref gnn_trainer_reset_type and \ref gnn_trainer_train_type.
 */



/******************************************/
/*              Include Files             */
/******************************************/

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include "gnn_trainer.h"
#include "gnn_utilities.h"



/******************************************/
/*           Static Declaration           */
/******************************************/

static int
gnn_trainer_default_reset (gnn_trainer *trainer);

static void
gnn_trainer_default_destroy (gnn_trainer *trainer);



/******************************************/
/*         Static Implementation          */
/******************************************/

/**
 * @brief Default "reset" function for a trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function is the default "reset" function for a trainer. It does
 * nothing.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return 0 if succeeded.
 */
static int
gnn_trainer_default_reset (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return 0;
}

/**
 * @brief Default "destroy" function for a trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function is the default "destroy" function for a trainer. It assumes
 * that there isn't any additional data for the specific trainer type, so
 * it just returns.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 */
static void
gnn_trainer_default_destroy (gnn_trainer *trainer)
{
  return;
}



/******************************************/
/*            Public Interface            */
/******************************************/

/**
 * @brief Get the trainer's type.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a string which contains the name of the trainer's
 * type. The string should not be modified.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to the trainer's own type string.
 */
const char *
gnn_trainer_get_type (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->type;
}

/**
 * @brief Initializes a \ref gnn_trainer structure.
 * @ingroup gnn_trainer_doc
 *
 * This function initializes a \ref gnn_trainer structure by setting its
 * internal fields adequately. Every extension to the \ref gnn_trainer
 * structure should have a constructor that calls this function to:
 *
 * - define the trainer's type,
 * - set the model to be trained (the node),
 * - set the criterion to be used,
 * - set the dataset with the patterns to be used during training,
 * - install the type-specific functions for "reset", "train" and "destroy".
 *
 * All arguments are mandatory, except for the "reset" and the "destroy"
 * functions: if they are specified as NULL, then the defaults are installed
 * instead (see \ref gnn_trainer_default_reset and
 * \ref gnn_trainer_default_destroy).
 *
 * The associated sizes should match. That is, the size of the input
 * patterns contained in the dataset should correspond to the node's input
 * size. Likewise, the size of the output patterns should match the node's
 * output size and the criterion's size.
 *
 * The following example should clarify its use. Suppose you want to create
 * a new trainer type, called "my_trainer". You have already implemented
 * two of the needed functions, "my_trainer_reset" and "my_trainer_train".
 * Your training function does not need any additional data, so the
 * destructor isn't needed.
 * Then, the following code could be a possible implementation of the
 * constructor:
 * \code
 gnn_trainer *
 my_trainer_new (gnn_node *model, gnn_criterion *crit, gnn_dataset *data)
 {
   int          status;
   my_trainer  *t;      // this is just an extension of gnn_trainer
   gnn_trainer *tview;  // points to the same structure, but uses another
                        // view

   // alloc space
   t = (my_trainer *) malloc (sizeof (*t));
   if (t == NULL)
     print_error_message_and_return_error_code ();

   // cast to a gnn_trainer
   tview = (gnn_trainer *) t;

   // initialize
   status = gnn_trainer_init (tview,
                              "my_trainer",
                              model,
                              crit,
                              data,
                              my_trainer_reset,
                              my_trainer_train,
                              NULL);
   if (status)
     print_error_message_and_return_error_code ();

   // finish initialization
   ...

   return tview;
 }
 * \endcode
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @param type    A string with the typename. It will be duplicated.
 * @param node    A pointer to a \ref gnn_node.
 * @param crit    A pointer to a \ref gnn_criterion.
 * @param data    A pointer to a \ref gnn_dataset.
 * @param reset   A pointer to the "reset" function (or NULL).
 * @param train   A pointer to the "train" function.
 * @param destroy A pointer to the "destroy" function (or NULL).
 * @return Returns 0 if initialization succeeded.
 */
int
gnn_trainer_init (gnn_trainer *trainer,
                  const char *type,
                  gnn_node *node,
                  gnn_criterion *crit,
                  gnn_dataset *data,
                  gnn_trainer_reset_type reset,
                  gnn_trainer_train_type train,
                  gnn_trainer_destroy_type destroy)
{
  assert (trainer != NULL);

  /* check arguments */
  if (type == NULL)
    GSL_ERROR ("invalid trainer type", GSL_EINVAL);
  if (node == NULL)
    GSL_ERROR ("invalid node", GSL_EINVAL);
  if (crit == NULL)
    GSL_ERROR ("invalid criterion", GSL_EINVAL);
  if (data == NULL)
    GSL_ERROR ("invalid dataset", GSL_EINVAL);
  if (train == NULL)
    GSL_ERROR ("invalid train function", GSL_EINVAL);

  /* check sizes */
  if (gnn_dataset_input_get_size (data) != gnn_node_input_get_size (node))
    GSL_ERROR ("the input pattern's size does not match the node's "
               "input size", GSL_EINVAL);
  if (gnn_dataset_output_get_size (data) != gnn_node_output_get_size (node))
    GSL_ERROR ("the output pattern's size does not match the node's "
               "output size", GSL_EINVAL);
  if (gnn_dataset_output_get_size (data) != gnn_criterion_get_size (crit))
    GSL_ERROR ("the output pattern's size does not match the "
               "criterion's size", GSL_EINVAL);

  /* check if sizes are strictly positive */
  if (gnn_node_input_get_size (node) <= 0)
    GSL_ERROR ("the input size should be strictly positive",
               GSL_EINVAL);
  if (gnn_node_output_get_size (node) <= 0)
    GSL_ERROR ("the output size should be strictly positive",
               GSL_EINVAL);
  if (gnn_node_param_get_size (node) <= 0)
    GSL_ERROR ("the parameter size should be strictly positive",
               GSL_EINVAL);

  /* set fields (defaults first; type-specific functions are installed
     below) */
  trainer->type    = strdup (type);
  trainer->node    = node;
  trainer->crit    = crit;
  trainer->data    = data;
  trainer->n       = 1;
  trainer->s       = 0;
  trainer->epoch   = 1;
  trainer->sume    = 0.0;
  trainer->sump    = 0.0;
  trainer->grad    = gnn_grad_new (node, crit, data);
  trainer->reset   = gnn_trainer_default_reset;
  trainer->train   = NULL;
  trainer->destroy = gnn_trainer_default_destroy;

  /* install functions */
  if (reset != NULL)
    trainer->reset = reset;
  if (train != NULL)
    trainer->train = train;
  if (destroy != NULL)
    trainer->destroy = destroy;

  return 0;
}

/**
 * @brief Destroys the trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function destroys the \ref gnn_trainer by calling the installed
 * destructor function for the trainer type and by freeing the
 * \ref gnn_trainer's structure.
 *
 * It doesn't destroy the node, the criterion, or the dataset. They
 * should be destroyed independently.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 */
void
gnn_trainer_destroy (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  /* destroy type-specific data */
  trainer->destroy (trainer);

  /* free structure */
  if (trainer->grad != NULL)
    gnn_grad_destroy (trainer->grad);
  if (trainer->type != NULL)
    free ((char *) trainer->type);
  free (trainer);
}

/**
 * @brief Resets the trainer.
 * @ingroup gnn_trainer_doc
 *
 * This function resets the trainer: the accumulated error is cleared, the
 * epoch counter is reset, the pattern iterator is set to the first pattern,
 * etc.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_reset (gnn_trainer *trainer)
{
  int status;

  assert (trainer != NULL);

  /* reset type-specific state */
  status = trainer->reset (trainer);

  /* reset common fields */
  trainer->s     = 0;
  trainer->epoch = 1;
  trainer->sume  = 0.0;
  trainer->sump  = 0.0;

  return status;
}

/**
 * @brief Trains the node.
 * @ingroup gnn_trainer_doc
 *
 * This function executes a full iteration of the training algorithm on the
 * node. In this context, an iteration consists of presenting a batch of
 * patterns to the node and adjusting its parameters accordingly.
 *
 * Every call to this function presents the next patterns until the end
 * is reached. The last batch will be smaller if the total number of
 * patterns is not a multiple of the batch size. When the end is reached,
 * the trainer advances to the next epoch and the dataset is automatically
 * reset.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns the mean cost, as computed by the criterion, over the
 *         patterns presented during the current epoch.
 */
double
gnn_trainer_train (gnn_trainer *trainer)
{
  double cost;

  assert (trainer != NULL);

  /* train */
  cost = trainer->train (trainer);

  return cost;
}

/**
 * @brief Get the dataset.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the dataset that the trainer uses for
 * training.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_dataset.
 */
gnn_dataset *
gnn_trainer_get_dataset (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->data;
}

/**
 * @brief Get the model.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the model that the trainer trains.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_node.
 */
gnn_node *
gnn_trainer_get_node (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->node;
}

/**
 * @brief Get the criterion.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the criterion that the trainer uses
 * for training.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to a \ref gnn_criterion.
 */
gnn_criterion *
gnn_trainer_get_criterion (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->crit;
}

/**
 * @brief Set the size of the minibatches.
 * @ingroup gnn_trainer_doc
 *
 * This function sets the size of the minibatches that are processed upon
 * a call of \ref gnn_trainer_train. This size should be strictly positive.
 * Although it can be greater than the number of available patterns in the
 * dataset, the effective minibatch will then be smaller.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @param size    The size of the minibatches.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_set_size (gnn_trainer *trainer, size_t size)
{
  assert (trainer != NULL);

  /* check batch size */
  if (size < 1)
    {
      GSL_ERROR ("the size of the batch should be strictly positive",
                 GSL_EINVAL);
    }

  /* set batch size */
  trainer->n = size;

  return 0;
}

/**
 * @brief Get the size of the minibatches.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a strictly positive number corresponding to the
 * size of the minibatches that are processed upon a call of
 * \ref gnn_trainer_train.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns the size of the minibatches.
 */
size_t
gnn_trainer_batch_get_size (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->n;
}

/**
 * @brief Returns the index of the next minibatch.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the index of the first pattern in the minibatch
 * to be processed.
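 *
 * As a sketch (the batch size of 4 and the dataset size of 10 are just
 * example values), the index advances by the batch size on every call of
 * \ref gnn_trainer_batch_next and wraps around at the end of the dataset:
 * \code
 * gnn_trainer_batch_set_size (trainer, 4);   // 10 patterns in total
 * gnn_trainer_get_pattern_index (trainer);   // == 0
 * gnn_trainer_batch_next (trainer);
 * gnn_trainer_get_pattern_index (trainer);   // == 4
 * gnn_trainer_batch_next (trainer);
 * gnn_trainer_get_pattern_index (trainer);   // == 8
 * gnn_trainer_batch_next (trainer);          // end reached: next epoch
 * gnn_trainer_get_pattern_index (trainer);   // == 0 again
 * \endcode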
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns the index of the next pattern.
 */
size_t
gnn_trainer_get_pattern_index (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->s;
}

/**
 * @brief Get the number of the current epoch.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the number of the current epoch.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns the number of the current epoch.
 */
size_t
gnn_trainer_get_epoch (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->epoch;
}

/**
 * @brief Get the mean cost.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the mean cost of the patterns presented in the
 * current epoch. That is, if the patterns \f$ x^1, \ldots, x^{k'} \f$ have
 * been evaluated in the current epoch, then the value
 * \f[ <E>_{k'} = \frac{1}{\sum_{k=1}^{k'} p_k}
 *                \sum_{k=1}^{k'} p_k E(y^k, t^k) \f]
 * where \f$p_k\f$ is the weight of the \f$k\f$-th pattern, will be returned.
 *
 * It is important to note that this value also includes the last evaluated
 * minibatch's cost.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns the weighted mean \f$<E>_{k'}\f$.
 */
double
gnn_trainer_get_epoch_cost (gnn_trainer *trainer)
{
  double Se;
  double Sp;

  assert (trainer != NULL);

  /* compute cost sum and weight sum */
  Se = trainer->sume + GNN_GRAD_SUME (trainer->grad);
  Sp = trainer->sump + GNN_GRAD_SUMP (trainer->grad);

  return Se / Sp;
}

/**
 * @brief Returns a pointer to the internal \ref gnn_grad buffer.
 * @ingroup gnn_trainer_doc
 *
 * This function returns a pointer to the internal error and gradient
 * evaluation buffer.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns a pointer to the \ref gnn_grad.
 */
gnn_grad *
gnn_trainer_batch_get_grad (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return trainer->grad;
}

/**
 * @brief Returns the last evaluated batch's mean error.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean error \f$<E>\f$.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<E>\f$.
 */
double
gnn_trainer_batch_get_e (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return GNN_GRAD_E (trainer->grad);
}

/**
 * @brief Returns the last batch's evaluated gradient dx.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean gradient with
 * respect to its inputs \f$<\frac{\partial E}{\partial x}>\f$.
 *
 * The returned vector can be freely accessed and its values modified, but
 * it should not be freed.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<\frac{\partial E}{\partial x}>\f$.
 */
gsl_vector *
gnn_trainer_batch_get_dx (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return GNN_GRAD_DX (trainer->grad);
}

/**
 * @brief Returns the last batch's evaluated gradient dw.
 * @ingroup gnn_trainer_doc
 *
 * This function returns the last evaluated batch's mean gradient with
 * respect to its parameters \f$<\frac{\partial E}{\partial w}>\f$.
 *
 * The returned vector can be freely accessed and its values modified, but
 * it should not be freed.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns \f$<\frac{\partial E}{\partial w}>\f$.
 */
gsl_vector *
gnn_trainer_batch_get_dw (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  return GNN_GRAD_DW (trainer->grad);
}

/**
 * @brief Processes a batch.
 * @ingroup gnn_trainer_doc
 *
 * This function processes the current minibatch of patterns, computing
 * the batch's mean cost and its gradients \f$\frac{\partial E}{\partial x}\f$
 * and \f$\frac{\partial E}{\partial w}\f$. They can be retrieved afterwards
 * by invoking
 *
 * - \ref gnn_trainer_batch_get_e
 * - \ref gnn_trainer_batch_get_dx
 * - \ref gnn_trainer_batch_get_dw
 * - \ref gnn_trainer_get_epoch_cost
 *
 * These values are obtained by averaging over the batch's patterns.
 *
 * Several calls of this function will process the same minibatch. Only
 * a call of \ref gnn_trainer_batch_next moves on to the next one.
 *
 * This function is very handy when implementing your own trainers. It
 * provides an easy way to obtain the gradient and handles the necessary
 * field updates on the \ref gnn_trainer. Please refer to
 * \ref gnn_trainer_train_type for details, and see the sketch at the end
 * of this file.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_process (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  /* process batch */
  gnn_grad_pats (trainer->grad, gnnGradDw, trainer->s, trainer->n);

  return 0;
}

/**
 * @brief Moves on to the next batch.
 * @ingroup gnn_trainer_doc
 *
 * This function moves on to the next batch. That is, it sets the trainer's
 * internal state to prepare to process the next batch. If the end of the
 * dataset is reached, then the next epoch is initiated, the mean cost
 * information is cleared, and the dataset is reset.
 *
 * This function is very handy when implementing your own trainers. It
 * should be used in conjunction with \ref gnn_trainer_batch_process.
 *
 * @param trainer A pointer to a \ref gnn_trainer.
 * @return Returns 0 if succeeded.
 */
int
gnn_trainer_batch_next (gnn_trainer *trainer)
{
  assert (trainer != NULL);

  /* update current pattern index */
  trainer->s += trainer->n;

  /* update mean cost values */
  trainer->sume += GNN_GRAD_SUME (trainer->grad);
  trainer->sump += GNN_GRAD_SUMP (trainer->grad);

  /* if end reached, advance to next epoch */
  if (trainer->s >= gnn_dataset_get_size (trainer->data))
    {
      trainer->epoch++;
      trainer->s = 0;
      trainer->sume = 0.0;
      trainer->sump = 0.0;
      gnn_dataset_reset (trainer->data);
    }

  return 0;
}
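
/*
 * For illustration, the following sketch shows how a minimal "train"
 * function of type gnn_trainer_train_type could be built on top of
 * gnn_trainer_batch_process and gnn_trainer_batch_next, implementing plain
 * gradient descent. The "my_trainer" structure (with a hypothetical
 * learning-rate field "eta") is the extension sketched in the documentation
 * above, and "gnn_node_param_get" is an assumed accessor for the node's
 * parameter vector; adapt both to the actual definitions in your code.
 *
 *   static double
 *   my_trainer_train (gnn_trainer *trainer)
 *   {
 *     my_trainer *t = (my_trainer *) trainer;
 *     gsl_vector *dw;
 *     gsl_vector *w;
 *     double      cost;
 *
 *     // evaluate the current minibatch: mean cost and gradients
 *     gnn_trainer_batch_process (trainer);
 *
 *     // gradient descent step: w <- w - eta * <dE/dw>
 *     dw = gnn_trainer_batch_get_dw (trainer);
 *     w  = gnn_node_param_get (gnn_trainer_get_node (trainer)); // assumed
 *     gsl_blas_daxpy (-t->eta, dw, w);
 *
 *     // record the epoch cost before moving on, then advance to the
 *     // next minibatch (and epoch, if the dataset is exhausted)
 *     cost = gnn_trainer_get_epoch_cost (trainer);
 *     gnn_trainer_batch_next (trainer);
 *
 *     return cost;
 *   }
 */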