Main Page   Modules   Data Structures   File List   Data Fields   Globals   Related Pages  

gnn_weight.c

Go to the documentation of this file.
00001 /***************************************************************************
00002  *  @file  gnn_weight.c
00003  *  @brief gnn_weight Implementation.
00004  *
00005  *  @date   : 15-07-03 19:56
00006  *  @author : Pedro Ortega C. <peortega@dcc.uchile.cl>
00007  *  Copyright  2003  Pedro Ortega C.
00008  ****************************************************************************/
00009 /*
00010  *  This program is free software; you can redistribute it and/or modify
00011  *  it under the terms of the GNU General Public License as published by
00012  *  the Free Software Foundation; either version 2 of the License, or
00013  *  (at your option) any later version.
00014  *
00015  *  This program is distributed in the hope that it will be useful,
00016  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
00017  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00018  *  GNU Library General Public License for more details.
00019  *
00020  *  You should have received a copy of the GNU General Public License
00021  *  along with this program; if not, write to the Free Software
00022  *  Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
00023  */
00024 
/**
 * @defgroup gnn_weight_doc gnn_weight : Affine / Linear Transform
 * @ingroup gnn_atomic_doc
 * @todo Include more transformations (1-D and 2-D), and check other uses.
 *
 * The \ref gnn_weight node type implements the function given by
 * \f[ Y = A X + B \f]
 * where \f$A\f$ is a matrix of size \f$m \times n\f$ and \f$B\f$
 * is a vector of size \f$m \times 1\f$. Basically, the \ref gnn_weight
 * computes an affine transform. In neural-network terminology, the matrix
 * \f$A\f$ is the "weights matrix" and \f$B\f$ is the "bias vector".
 *
 * The same function written in index notation is:
 * \f[ y_j = \sum_i A_{ji} x_i + B_{j} \f]
 * where the index \f$j\f$ runs over all outputs, i.e. \f$j=1,\ldots,m\f$,
 * and \f$i\f$ runs over the inputs, \f$i=1,\ldots,n\f$. The following figure
 * shows how the different terms affect the inputs:
 *
 * <img src="images/gnn_weight1.png">
 *
 * The \ref gnn_weight node is very versatile. Depending on the point of
 * view, it can be seen as carrying out many different types of functions.
 *
 * A new weight layer can be built using the \ref gnn_weight_new function.
 * There are several ways to initialize or set its values: they can be
 * set, for example, randomly with \ref gnn_weight_init, or fixed to
 * alternative linear transforms such as the Cosine Transform, the Hadamard
 * transform or the Karhunen-Loeve transform. The matrix \f$A\f$ can also be
 * thought of as a graph's adjacency matrix, where \f$A_{ji}\f$ corresponds
 * to the connection strength between the input \f$x_i\f$ and the output
 * \f$y_j\f$.
 */
00057 
00058 /******************************************/
00059 /* Include Files                          */
00060 /******************************************/
00061 
#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdlib.h>
#include <gsl/gsl_blas.h>
#include <gsl/gsl_randist.h>
#include <gsl/gsl_eigen.h>
#include "gnn_utilities.h"
#include "gnn_weight.h"
00068 
00069 
00070 
00071 /******************************************/
00072 /* Static Declaration                     */
00073 /******************************************/
00074 
00075 /**
00076  * @brief The structure for a \ref gnn_weight node.
00077  * @ingroup gnn_weight_doc
00078  *
00079  * This datatype holds the information for a \ref gnn_weight node. Basically,
00080  * it extends the \ref gnn_node with special pointers to get fast accesses
00081  * to the matrices and vectors \f$A\f$, \f$B\f$,
00082  * \f$\frac{\partial E}{\partial A}\f$, \f$\frac{\partial E}{\partial B}\f$
00083  * and their frozen flags in order to improve the efficiency.
00084  */
00085 typedef struct _gnn_weight gnn_weight;
00086 
00087 struct _gnn_weight
00088 {
00089     gnn_node node;               /**< The underlying \ref gnn_node. */
00090 
00091     int linear;                  /**< No use yet. */
00092 
00093     gsl_matrix_view      A_view; /**< View of matrix A.             */
00094     gsl_vector_view      B_view; /**< View of vector B.             */
00095     gsl_matrix_view     dA_view; /**< View of matrix dA.            */
00096     gsl_vector_view     dB_view; /**< View of vector dB.            */
00097     gsl_matrix_int_view Af_view; /**< View of flags for A.          */
00098     gsl_vector_int_view Bf_view; /**< View of flags for B.          */
00099 
00100     gsl_matrix     *A;           /**< Pointer to matrix A.          */
00101     gsl_vector     *B;           /**< Pointer to vector B.          */
00102     gsl_matrix     *dA;          /**< Pointer to matrix dA.         */
00103     gsl_vector     *dB;          /**< Pointer to vector dB.         */
00104     gsl_matrix_int *Af;          /**< Pointer to flags for A.       */
00105     gsl_vector_int *Bf;          /**< Pointer to flags for B.       */
00106 
00107 };
00108 
00109 
00110 static int
00111 gnn_weight_f (gnn_node *node,
00112               const gsl_vector *x,
00113               const gsl_vector *w,
00114               gsl_vector *y);
00115 
00116 
00117 static int
00118 gnn_weight_dx (gnn_node *node,
00119                const gsl_vector *x,
00120                const gsl_vector *w,
00121                const gsl_vector *dy,
00122                gsl_vector *dx);
00123 
00124 static int
00125 gnn_weight_dw (gnn_node *node,
00126                const gsl_vector *x,
00127                const gsl_vector *w,
00128                const gsl_vector *dy,
00129                gsl_vector *dw);
00130 
00131 
00132 static void
00133 gnn_weight_destroy (gnn_node *node);
00134 
00135 
00136 
00137 static int
00138 gnn_weight_hadamard_recurse (gsl_matrix *A, double fact);
00139 
00140 
00141 
00142 /******************************************/
00143 /* Static Implementation                  */
00144 /******************************************/
00145 
00146 /**
00147  * @brief Computes the evaluation.
00148  * @ingroup gnn_weight_doc
00149  *
00150  * This functions computes:
00151  * \f[ y_j = \sum_{i=1}^n A_{ji} x_i + B_i \f]
00152  * or written in matrix notation:
00153  * \f[ Y = A x + B \f]
00154  *
00155  * @param node A pointer to a \ref gnn_weight node.
00156  * @param x    A pointer to the input vector \f$x\f$.
00157  * @param w    A pointer to the parameter vector \f$w\f$.
00158  * @param y    A pointer to the output vector \f$y\f$ where the result should
00159  *             be placed.
00160  * @return Returns 0 if suceeded.
00161  */
00162 static int
00163 gnn_weight_f (gnn_node *node,
00164               const gsl_vector *x,
00165               const gsl_vector *w,
00166               gsl_vector *y)
00167 {
00168     gnn_weight *wnode;
00169 
00170     wnode = (gnn_weight *) node;
00171 
00172     /* compute output */
00173     gsl_vector_add (y, wnode->B);
00174     gsl_blas_dgemv (CblasNoTrans, 1.0, wnode->A, x, 0.0, y);
00175 
00176     return 0;
00177 }
00178 
00179 /**
00180  * @brief Computes the \ref gnn_weight gradient with respect to its inputs.
00181  * @ingroup gnn_weight_doc
00182  *
00183  * This functions computes:
00184  * \f[\frac{\partial E}{\partial x_j} =
00185  *     \sum_j A_{ji} \frac{\partial E}{\partial y_j}\f]
00186  * or, written in matrix notation:
00187  * \f[ \frac{\partial E}{\partial x} =
00188  *     A^t \frac{\partial E}{\partial y} \f]
00189  *
00190  * @param node A pointer to a \ref gnn_weight node.
00191  * @param x    A pointer to the input vector \f$x\f$.
00192  * @param w    A pointer to the parameter vector \f$w\f$.
00193  * @param dy   A pointer to the vector \f$\frac{\partial E}{\partial y}\f$.
00194  * @param dx   A pointer to the vector \f$\frac{\partial E}{\partial x}\f$
00195  *             where the result should be placed.
00196  * @return Returns 0 if suceeded.
00197  */
00198 static int
00199 gnn_weight_dx (gnn_node *node,
00200                const gsl_vector *x,
00201                const gsl_vector *w,
00202                const gsl_vector *dy,
00203                gsl_vector *dx)
00204 {
00205     gnn_weight *wnode;
00206 
00207     wnode = (gnn_weight *) node;
00208 
00209     /* compute output */
00210     gsl_blas_dgemv (CblasTrans, 1.0, wnode->A, dy, 0.0, dx);
00211 
00212     return 0;
00213 }
00214 
00215 
00216 /**
00217  * @brief Computes the \ref gnn_weight gradient with respect to its parameters.
00218  * @ingroup gnn_weight_doc
00219  *
00220  * This functions computes:
00221  * \f[ \frac{\partial E}{\partial A_{ji}} =
00222  *     x_i \frac{\partial E}{\partial y_j} \f]
00223  * \f[ \frac{\partial E}{\partial B_i} =
00224  *     \frac{\partial E}{\partial y_j} \f]
00225  * where \f$x_0\f$ is defined as \f$1\f$. Or written in matrix notation:
00226  * \f[ \frac{\partial E}{\partial A} =
00227  *     \frac{\partial E}{\partial y}^t x \f]
00228  * and
00229  * \f[ \frac{\partial E}{\partial B} =
00230  *     \frac{\partial E}{\partial y}^t \f]
00231  *
00232  * @param node A pointer to a \ref gnn_weight node.
00233  * @param x    A pointer to the input vector \f$x\f$.
00234  * @param w    A pointer to the parameter vector \f$w\f$.
00235  * @param dy   A pointer to the vector \f$\frac{\partial E}{\partial y}\f$.
00236  * @param dw   A pointer to the vector \f$\frac{\partial E}{\partial w}\f$
00237  *             where the result should be placed.
00238  * @return Returns 0 if suceeded.
00239  */
00240 static int
00241 gnn_weight_dw (gnn_node *node,
00242                const gsl_vector *x,
00243                const gsl_vector *w,
00244                const gsl_vector *dy,
00245                gsl_vector *dw)
00246 {
00247     int input_size;
00248     int output_size;
00249     gsl_matrix_view X;
00250     gsl_matrix_view DY;
00251     gnn_weight *wnode;
00252 
00253     wnode = (gnn_weight *) node;
00254 
00255     /* get sizes */
00256     input_size  = gnn_node_input_get_size (node);
00257     output_size = gnn_node_output_get_size (node);
00258 
00259     /* get matrix and vector view */
00260     X  = gsl_matrix_view_vector ((gsl_vector *) x, 1, input_size);
00261     DY = gsl_matrix_view_vector ((gsl_vector *) dy, output_size, 1);
00262 
00263     /* compute output */
00264     gsl_blas_dgemm (CblasNoTrans, CblasNoTrans, 1.0, &(DY.matrix),
00265                                                 &(X.matrix), 1.0, wnode->dA);
00266     gsl_vector_add (wnode->dB, dy);
00267 
00268     return 0;
00269 }
00270 
00271 /**
00272  * @brief The destroy function for a \ref gnn_weight.
00273  * @ingroup gnn_weight_doc
00274  *
00275  * This functions computes is the \ref gnn_weight destroy function, which
00276  * destroys the type's specific data.
00277  *
00278  * @param node A pointer to a \ref gnn_weight.
00279  */
00280 static void
00281 gnn_weight_destroy (gnn_node *node)
00282 {
00283     return;
00284 }
00285 
00286 
00287 /******************************************/
00288 /* Public Implementation                  */
00289 /******************************************/
00290 
00291 /**
00292  * @brief Creates a new \ref gnn_weight layer.
00293  * @ingroup gnn_weight_doc
00294  *
00295  * This function creates a new \ref gnn_node of \ref gnn_weight type.
00296  *
00297  * @param input_size The node's input size.
00298  * @param output_size The node's output size.
00299  * @return A new node.
00300  */
00301 gnn_node *
00302 gnn_weight_new (int input_size, int output_size)
00303 {
00304     int status;
00305     int param_size;
00306     gnn_node   *node;
00307     gnn_weight *wnode;
00308     gsl_vector *w;
00309     gsl_vector *dw;
00310     gsl_vector_int *f;
00311     
00312     /* check if sizes are positive */
00313     if (input_size < 1 || output_size < 1)
00314     {
00315         GSL_ERROR_VAL ("gnn_weight's sizes should be stricly positive",
00316                        GSL_EINVAL, NULL);
00317     }
00318     
00319     /* compute amount of needed parameters */
00320     param_size = input_size * output_size + output_size;
00321 
00322     /* allocate the node */
00323     node = (gnn_node *) malloc (sizeof (gnn_weight));
00324     if (node == NULL)
00325         GSL_ERROR_VAL ("couldn't allocate memory for gnn_weight's node",
00326                        GSL_ENOMEM, NULL);
00327 
00328     /* Initialize the node */
00329     status = gnn_node_init (node,
00330                             "gnn_weight",
00331                             gnn_weight_f,
00332                             gnn_weight_dx,
00333                             gnn_weight_dw,
00334                             gnn_weight_destroy);
00335     if (status)
00336         GSL_ERROR_VAL ("could not initialize weights node", GSL_EINVAL, NULL);
00337         
00338     status = gnn_node_set_sizes (node, input_size, output_size, param_size);
00339     if (status)
00340         GSL_ERROR_VAL ("could not set sizes for weights node", GSL_EINVAL, NULL);
00341 
00342     /* get references for building views */
00343     wnode = (gnn_weight *) node;
00344     w     = gnn_node_local_get_w (node);
00345     dw    = gnn_node_local_get_dw (node);
00346     f     = gnn_node_local_get_f (node);
00347 
00348     /* build views */
00349     wnode->A_view  = gsl_matrix_view_vector (w, output_size, input_size);
00350     wnode->B_view  = gsl_vector_subvector (w, input_size * output_size,
00351                                                  output_size);
00352                                                  
00353     wnode->dA_view = gsl_matrix_view_vector (dw, output_size, input_size);
00354     wnode->dB_view = gsl_vector_subvector (dw, input_size * output_size,
00355                                                   output_size);
00356 
00357     wnode->Af_view = gsl_matrix_int_view_vector (f, output_size, input_size);
00358     wnode->Bf_view = gsl_vector_int_subvector (f, input_size * output_size,
00359                                                      output_size);
00360 
00361     /* build pointers */
00362     wnode->A = &(wnode->A_view.matrix);
00363     wnode->B = &(wnode->B_view.vector);
00364     wnode->dA = &(wnode->dA_view.matrix);
00365     wnode->dB = &(wnode->dB_view.vector);
00366     wnode->Af = &(wnode->Af_view.matrix);
00367     wnode->Bf = &(wnode->Bf_view.vector);
00368         
00369     return node;
00370 }
00371 
00372 /**
00373  * @brief Initializes the layer's parameters.
00374  * @ingroup gnn_weight_doc
00375  *
00376  * This functions initializes the layer's parameters by values drawn from
00377  * a uniform distribution with mean zero and standard deviation
00378  * \f[ \sigma = (n + 1)^{-\frac{1}{2}}  \f]
00379  * to insure that each of the outputs \f$y_j\f$ are approximately \f$1\f$.
00380  *
00381  * Implementation Note: The numbers are generated by
00382  * \f[ \frac {1}{ \sqrt {12} \sigma } (x - \frac{1}{2}) \f]
00383  * where \f$x\f$ is drawn from a uniform distribution in \f$[0, 1)\f$.
00384  *
00385  * @param layer A pointer to a \ref gnn_weight node.
00386  * @return Returns 0 on success.
00387  */
00388 int
00389 gnn_weight_init (gnn_node *node)
00390 {
00391     int    i;
00392     int    l;
00393     double sigma;
00394     double sqrt12;
00395     double amplitude;
00396     gsl_rng    *r;
00397     gsl_vector *w;
00398 
00399     /* get the random number generator */
00400     r = gnn_get_rng ();
00401 
00402     /* get the parameters */
00403     w = gnn_node_local_get_w (node);
00404     l = w->size;
00405 
00406     /* compute the amplitude of the uniform distribution */
00407     sigma      = 1.0 / sqrt (l);
00408     sqrt12     = sqrt (12.0);
00409     amplitude  = 1.0 / (sqrt12 * sigma);
00410 
00411     /* initialize weights */
00412     for (i=0; i<l; ++i)
00413     {
00414         double random_value;
00415         random_value = amplitude * (gsl_rng_uniform (r) - 0.5);
00416         gsl_vector_set (w, i, random_value);
00417     }
00418 
00419     /* update change */
00420     gnn_node_local_update (node);
00421 
00422     return 0;
00423 }
00424 
00425 /**
00426  * @brief Gets the node's parameter matrix \f$A\f$.
00427  * @ingroup gnn_weight_doc
00428  *
00429  * This function a pointer to the node's internal parameter matrix \f$A\f$.
00430  * Its values can be freely modified, but after changes are made, the
00431  * \ref gnn_node_local_update function should be called.
00432  *
00433  * @param node A pointer to a \ref gnn_weight.
00434  * @return A pointer to a matrix.
00435  */
00436 gsl_matrix *
00437 gnn_weight_get_A (gnn_node *node)
00438 {
00439     gnn_weight *wnode;
00440 
00441     assert (node != NULL);
00442 
00443     gnn_node_local_get_w (node);
00444     wnode = (gnn_weight *) node;
00445     return wnode->A;
00446 }
00447 
00448 /**
00449  * @brief Gets the node's parameter vector \f$B\f$.
00450  * @ingroup gnn_weight_doc
00451  *
00452  * This function a pointer to the node's internal parameter vector \f$B\f$.
00453  * Its values can be freely modified, but after changes are made, the
00454  * \ref gnn_node_local_update function should be called.
00455  *
00456  * @param node A pointer to a \ref gnn_node.
00457  * @return A pointer to a vector.
00458  */
00459 gsl_vector *
00460 gnn_weight_get_B (gnn_node *node)
00461 {
00462     gnn_weight *wnode;
00463 
00464     assert (node != NULL);
00465 
00466     gnn_node_local_get_w (node);
00467     wnode = (gnn_weight *) node;
00468     return wnode->B;
00469 }
00470 
00471 /**
00472  * @brief Gets \f$\frac{\partial E}{\partial A}\f$.
00473  * @ingroup gnn_weight_doc
00474  *
00475  * This function a pointer to the node's internal parameter matrix
00476  * gradient \f$\frac{\partial E}{\partial A}\f$.
00477  * Its values can be freely modified, but after changes are made, the
00478  * \ref gnn_node_local_update function should be called.
00479  *
00480  * @param node A pointer to a \ref gnn_weight.
00481  * @return A pointer to a matrix.
00482  */
00483 gsl_matrix *
00484 gnn_weight_get_dA (gnn_node *node)
00485 {
00486     gnn_weight *wnode;
00487 
00488     assert (node != NULL);
00489 
00490     gnn_node_local_get_dw (node);
00491     wnode = (gnn_weight *) node;
00492     return wnode->dA;
00493 }
00494 
00495 /**
00496  * @brief Gets \f$\frac{\partial E}{\partial B}\f$.
00497  * @ingroup gnn_weight_doc
00498  *
00499  * This function a pointer to the node's internal parameter vector gradient
00500  * \f$\frac{\partial E}{\partial B}\f$.
00501  * Its values can be freely modified, but after changes are made, the
00502  * \ref gnn_node_local_update function should be called.
00503  *
00504  * @param node A pointer to a \ref gnn_node.
00505  * @return A pointer to a vector.
00506  */
00507 gsl_vector *
00508 gnn_weight_get_dB (gnn_node *node)
00509 {
00510     gnn_weight *wnode;
00511 
00512     assert (node != NULL);
00513 
00514     gnn_node_local_get_dw (node);
00515     wnode = (gnn_weight *) node;
00516     return wnode->dB;
00517 }
00518 
00519 /**
00520  * @brief Gets the frozen flags for the matrix \f$A\f$.
00521  * @ingroup gnn_weight_doc
00522  *
00523  * This function returns a pointer to the node's internal matrix frozen flags.
00524  * Its values can be freely modified, but after changes are made, the
00525  * \ref gnn_node_local_update function should be called. Remember that
00526  * these values should be binary, beeing either 0 (free) or 1 (frozen).
00527  *
00528  * @param node A pointer to a \ref gnn_weight.
00529  * @return A pointer to a matrix.
00530  */
00531 gsl_matrix_int *
00532 gnn_weight_get_Af (gnn_node *node)
00533 {
00534     gnn_weight *wnode;
00535 
00536     assert (node != NULL);
00537 
00538     gnn_node_local_get_f (node);
00539     wnode = (gnn_weight *) node;
00540     return wnode->Af;
00541 }
00542 
00543 /**
00544  * @brief Gets the frozen flags for the vector \f$B\f$.
00545  * @ingroup gnn_weight_doc
00546  *
00547  * This function returns a pointer to the node's internal vector frozen flags.
00548  * Its values can be freely modified, but after changes are made, the
00549  * \ref gnn_node_local_update function should be called. Remember that
00550  * these values should be binary, beeing either 0 (free) or 1 (frozen).
00551  *
00552  * @param node A pointer to a \ref gnn_node.
00553  * @return A pointer to a vector.
00554  */
00555 gsl_vector_int *
00556 gnn_weight_get_Bf (gnn_node *node)
00557 {
00558     gnn_weight *wnode;
00559 
00560     assert (node != NULL);
00561 
00562     gnn_node_local_get_f (node);
00563     wnode = (gnn_weight *) node;
00564     return wnode->Bf;
00565 }
00566 
00567 
00568 
/**
 * @brief Connectionist Tools for gnn_weight.
 * @defgroup gnn_weight_conn gnn_weight Connectionist Tools.
 * @ingroup gnn_weight_doc
 *
 * This set of functions provides the functionality needed to manipulate a
 * \ref gnn_weight node from a connectionist's point of view. That is, the
 * \ref gnn_weight node is viewed as a device that maps the inputs to the
 * outputs via connections of different strengths.
 *
 * The functions in this section allow establishing individual connections
 * from a source \f$x_i\f$ to a destination \f$y_j\f$. The biases are special
 * connections whose sources are constant and equal to 1.
 */
00583 
00584 
00585 
00586 /**
00587  * @brief Connect an input with an output.
00588  * @ingroup gnn_weight_conn
00589  *
00590  * Connects the i-th input to the j-th output with the given strength and
00591  * unfreezes it.
00592  *
00593  * @param  node A \ref gnn_weight node.
00594  * @param  i    The index of the input.
00595  * @param  j    The index of the output.
00596  * @param  strength The strength of the connection.
00597  * @return Returns 0 if suceeded.
00598  */
00599 int
00600 gnn_weight_connect (gnn_node *node, size_t i, size_t j, double strength)
00601 {
00602     gnn_weight *wnode;
00603 
00604     assert (node != NULL);
00605 
00606     gnn_node_local_get_w (node);
00607     gnn_node_local_get_f (node);
00608     wnode = (gnn_weight *) node;
00609 
00610     /* set the strength and open parameter */
00611     gsl_matrix_set (wnode->A, j, i, strength);
00612     gsl_matrix_int_set (wnode->Af, j, i, 0);
00613     gnn_node_local_update (node);
00614 
00615     return 0;
00616 }
00617 
00618 /**
00619  * @brief Disconnect an input with an output.
00620  * @ingroup gnn_weight_conn
00621  *
00622  * Disconnects the i-th input to the j-th output. The weight will be frozen
00623  * to 0.
00624  *
00625  * @param  node A \ref gnn_weight node.
00626  * @param  i    The index of the input.
00627  * @param  j    The index of the output.
00628  * @return Returns 0 if suceeded.
00629  */
00630 int
00631 gnn_weight_disconnect (gnn_node *node, size_t i, size_t j)
00632 {
00633     gnn_weight *wnode;
00634 
00635     assert (node != NULL);
00636 
00637     gnn_node_local_get_w (node);
00638     gnn_node_local_get_f (node);
00639     wnode = (gnn_weight *) node;
00640 
00641     /* set the strength and freeze parameter */
00642     gsl_matrix_set (wnode->A, j, i, 0.0);
00643     gsl_matrix_int_set (wnode->Af, j, i, 1);
00644     gnn_node_local_update (node);
00645 
00646     return 0;
00647 }
00648 
00649 /**
00650  * @brief Creates a full connection.
00651  * @ingroup gnn_weight_conn
00652  *
00653  * Connects the all input to the outputs with the given strength. The
00654  * connections will be unfrozen.
00655  *
00656  * @param  node A \ref gnn_weight node.
00657  * @param  strength The strength of the connection.
00658  * @return Returns 0 if suceeded.
00659  */
00660 int
00661 gnn_weight_connect_all (gnn_node *node, double strength)
00662 {
00663     gnn_weight *wnode;
00664 
00665     assert (node != NULL);
00666 
00667     gnn_node_local_get_w (node);
00668     gnn_node_local_get_f (node);
00669     wnode = (gnn_weight *) node;
00670 
00671     /* set the strength and open parameter */
00672     gsl_matrix_set_all (wnode->A, strength);
00673     gsl_matrix_int_set_all (wnode->Af, 0);
00674     gnn_node_local_update (node);
00675 
00676     return 0;
00677 }
00678 
00679 /**
00680  * @brief Disconnect all connections.
00681  * @ingroup gnn_weight_conn
00682  *
00683  * Disconnects all connections. The weights will be frozen to 0.
00684  *
00685  * @param  node A \ref gnn_weight node.
00686  * @return Returns 0 if suceeded.
00687  */
00688 int
00689 gnn_weight_disconnect_all (gnn_node *node)
00690 {
00691     gnn_weight *wnode;
00692 
00693     assert (node != NULL);
00694 
00695     gnn_node_local_get_w (node);
00696     gnn_node_local_get_f (node);
00697     wnode = (gnn_weight *) node;
00698 
00699     /* set the strength and freeze parameter */
00700     gsl_matrix_set_zero (wnode->A);
00701     gsl_matrix_int_set_all (wnode->Af, 1);
00702     gnn_node_local_update (node);
00703 
00704     return 0;
00705 }
00706 
00707 /**
00708  * @brief Connects the i-th input to all outputs with given strength.
00709  * @ingroup gnn_weight_conn
00710  *
00711  * Connects the input at index \a i to all outputs with the given strengh
00712  * \a strength. The connections will be unfrozen.
00713  *
00714  * @param  node A \ref gnn_weight node.
00715  * @param  i    The input's index.
00716  * @param  strength The strength of the connections.
00717  * @return Returns 0 if suceeded.
00718  */
00719 int
00720 gnn_weight_connect_input (gnn_node *node, size_t i, double strength)
00721 {
00722     gnn_weight *wnode;
00723     gsl_vector_view Acol;
00724     gsl_vector_int_view Afcol;
00725 
00726     assert (node != NULL);
00727 
00728     /* check that the input index is within bounds */
00729     if (i < 0 || wnode->A->size2 <= i)
00730     {
00731         GSL_ERROR ("the input's index is out of bounds", GSL_EINVAL);
00732     }
00733 
00734     gnn_node_local_get_w (node);
00735     gnn_node_local_get_f (node);
00736     wnode = (gnn_weight *) node;
00737 
00738     /* build the views for the i-th column */
00739     Acol  = gsl_matrix_column (wnode->A, i);
00740     Afcol = gsl_matrix_int_column (wnode->Af, i);
00741 
00742     /* set the strength and open parameter */
00743     gsl_vector_set_all (&(Acol.vector), strength);
00744     gsl_vector_int_set_all (&(Afcol.vector), 0);
00745     gnn_node_local_update (node);
00746 
00747     return 0;
00748 }
00749 
00750 /**
00751  * @brief Disconnect the i-th input from all outputs.
00752  * @ingroup gnn_weight_conn
00753  *
00754  * Disconnects the i-th input from all outputs. The weights will be frozen to 0.
00755  *
00756  * @param  node A \ref gnn_weight node.
00757  * @param  i    The input's index.
00758  * @return Returns 0 if suceeded.
00759  */
00760 int
00761 gnn_weight_disconnect_input (gnn_node *node, size_t i)
00762 {
00763     gnn_weight *wnode;
00764     gsl_vector_view Acol;
00765     gsl_vector_int_view Afcol;
00766 
00767     assert (node != NULL);
00768 
00769     /* check that the input index is within bounds */
00770     if (i < 0 || wnode->A->size2 <= i)
00771     {
00772         GSL_ERROR ("the input's index is out of bounds", GSL_EINVAL);
00773     }
00774 
00775     gnn_node_local_get_w (node);
00776     gnn_node_local_get_f (node);
00777     wnode = (gnn_weight *) node;
00778 
00779     /* build the views for the i-th column */
00780     Acol  = gsl_matrix_column (wnode->A, i);
00781     Afcol = gsl_matrix_int_column (wnode->Af, i);
00782 
00783     /* set the strength and freeze parameter */
00784     gsl_vector_set_zero (&(Acol.vector));
00785     gsl_vector_int_set_all (&(Afcol.vector), 1);
00786     gnn_node_local_update (node);
00787 
00788     return 0;
00789 }
00790 
00791 /**
00792  * @brief Freeze a connection.
00793  * @ingroup gnn_weight_conn
00794  *
00795  * Freezes the connection from the i-th input to he j-th output.
00796  *
00797  * @param  node A \ref gnn_weight node.
00798  * @param  i    The index of the input.
00799  * @param  j    The index of the output.
00800  * @return Returns 0 if suceeded.
00801  */
00802 int
00803 gnn_weight_freeze_connection (gnn_node *node, size_t i, size_t j)
00804 {
00805     gnn_weight *wnode;
00806 
00807     assert (node != NULL);
00808 
00809     gnn_node_local_get_f (node);
00810     wnode = (gnn_weight *) node;
00811 
00812     /* parameter */
00813     gsl_matrix_int_set (wnode->Af, j, i, 1);
00814     gnn_node_local_update (node);
00815 
00816     return 0;
00817 }
00818 
00819 /**
00820  * @brief Unfreeze a connection.
00821  * @ingroup gnn_weight_conn
00822  *
00823  * Unfreezes the connection going from the i-th input to the j-th output.
00824  *
00825  * @param  node A \ref gnn_weight node.
00826  * @param  i    The index of the input.
00827  * @param  j    The index of the output.
00828  * @return Returns 0 if suceeded.
00829  */
00830 int
00831 gnn_weight_unfreeze_connection (gnn_node *node, size_t i, size_t j)
00832 {
00833     gnn_weight *wnode;
00834 
00835     assert (node != NULL);
00836 
00837     gnn_node_local_get_f (node);
00838     wnode = (gnn_weight *) node;
00839 
00840     /* parameter */
00841     gsl_matrix_int_set (wnode->Af, j, i, 0);
00842     gnn_node_local_update (node);
00843 
00844     return 0;
00845 }
00846 
00847 /**
00848  * @brief Freeze all connections.
00849  * @ingroup gnn_weight_conn
00850  *
00851  * Freezes all the connections.
00852  *
00853  * @param  node A \ref gnn_weight node.
00854  * @return Returns 0 if suceeded.
00855  */
00856 int
00857 gnn_weight_freeze_all_connections (gnn_node *node)
00858 {
00859     gnn_weight *wnode;
00860 
00861     assert (node != NULL);
00862 
00863     gnn_node_local_get_f (node);
00864     wnode = (gnn_weight *) node;
00865 
00866     /* parameter */
00867     gsl_matrix_int_set_all (wnode->Af, 1);
00868     gnn_node_local_update (node);
00869 
00870     return 0;
00871 }
00872 
00873 /**
00874  * @brief Unfreeze all connections.
00875  * @ingroup gnn_weight_conn
00876  *
00877  * Unfreezes all the connections.
00878  *
00879  * @param  node A \ref gnn_weight node.
00880  * @return Returns 0 if suceeded.
00881  */
00882 int
00883 gnn_weight_unfreeze_all_connections (gnn_node *node)
00884 {
00885     gnn_weight *wnode;
00886 
00887     assert (node != NULL);
00888 
00889     gnn_node_local_get_f (node);
00890     wnode = (gnn_weight *) node;
00891 
00892     /* parameter */
00893     gsl_matrix_int_set_all (wnode->Af, 0);
00894     gnn_node_local_update (node);
00895 
00896     return 0;
00897 }
00898 
00899 /**
00900  * @brief Connects a bias.
00901  * @ingroup gnn_weight_conn
00902  *
00903  * Connects the bias associated with the j-th output with the given strength.
00904  * The value will be unfrozen.
00905  *
00906  * @param  node     A \ref gnn_weight node.
00907  * @param  j        The index of the output associated to the bias.
00908  * @param  strength The strength of the connection.
00909  * @return Returns 0 if suceeded.
00910  */
00911 int
00912 gnn_weight_connect_bias (gnn_node *node, size_t j, double strength)
00913 {
00914     gnn_weight *wnode;
00915 
00916     assert (node != NULL);
00917 
00918     gnn_node_local_get_w (node);
00919     gnn_node_local_get_f (node);
00920     wnode = (gnn_weight *) node;
00921 
00922     /* set the strength and open parameter */
00923     gsl_vector_set (wnode->B, j, strength);
00924     gsl_vector_int_set (wnode->Bf, j, 0);
00925     gnn_node_local_update (node);
00926 
00927     return 0;
00928 }
00929 
00930 /**
00931  * @brief Disconnects a bias.
00932  * @ingroup gnn_weight_conn
00933  *
00934  * Disconnects the bias associated with the j-th output.
00935  * The value will be frozen to 0.
00936  *
00937  * @param  node     A \ref gnn_weight node.
00938  * @param  j        The index of the output associated to the bias.
00939  * @return Returns 0 if suceeded.
00940  */
00941 int
00942 gnn_weight_disconnect_bias (gnn_node *node, size_t j)
00943 {
00944     gnn_weight *wnode;
00945 
00946     assert (node != NULL);
00947 
00948     gnn_node_local_get_w (node);
00949     gnn_node_local_get_f (node);
00950     wnode = (gnn_weight *) node;
00951 
00952     /* set the strength and freeze parameter */
00953     gsl_vector_set (wnode->B, j, 0.0);
00954     gsl_vector_int_set (wnode->Bf, j, 1);
00955     gnn_node_local_update (node);
00956 
00957     return 0;
00958 }
00959 
00960 /**
00961  * @brief Connects all biases.
00962  * @ingroup gnn_weight_conn
00963  *
00964  * Connects all biases with the given strength.
00965  * All its values will be unfrozen.
00966  *
00967  * @param  node     A \ref gnn_weight node.
00968  * @param  strength The strength of the connection.
00969  * @return Returns 0 if suceeded.
00970  */
00971 int
00972 gnn_weight_connect_all_bias (gnn_node *node, double strength)
00973 {
00974     gnn_weight *wnode;
00975 
00976     assert (node != NULL);
00977 
00978     gnn_node_local_get_w (node);
00979     gnn_node_local_get_f (node);
00980     wnode = (gnn_weight *) node;
00981 
00982     /* set the strength and open parameter */
00983     gsl_vector_set_all (wnode->B, strength);
00984     gsl_vector_int_set_zero (wnode->Bf);
00985     gnn_node_local_update (node);
00986 
00987     return 0;
00988 }
00989 
00990 /**
00991  * @brief Disconnects all biases.
00992  * @ingroup gnn_weight_conn
00993  *
00994  * Disconnects all biases. All its values will be frozen to 0.
00995  *
00996  * @param  node     A \ref gnn_weight node.
00997  * @return Returns 0 if suceeded.
00998  */
00999 int
01000 gnn_weight_disconnect_all_bias (gnn_node *node)
01001 {
01002     gnn_weight *wnode;
01003 
01004     assert (node != NULL);
01005 
01006     gnn_node_local_get_w (node);
01007     gnn_node_local_get_f (node);
01008     wnode = (gnn_weight *) node;
01009 
01010     /* set the strength and freeze parameter */
01011     gsl_vector_set_all (wnode->B, 0.0);
01012     gsl_vector_int_set_all (wnode->Bf, 1);
01013     gnn_node_local_update (node);
01014 
01015     return 0;
01016 }
01017 
01018 /**
01019  * @brief Freeze a bias.
01020  * @ingroup gnn_weight_conn
01021  *
01022  * Freezes the bias associated to the j-th output.
01023  *
01024  * @param  node A \ref gnn_weight node.
01025  * @param  j    The index of the output.
01026  * @return Returns 0 if suceeded.
01027  */
01028 int
01029 gnn_weight_freeze_bias (gnn_node *node, size_t j)
01030 {
01031     gnn_weight *wnode;
01032 
01033     assert (node != NULL);
01034 
01035     gnn_node_local_get_f (node);
01036     wnode = (gnn_weight *) node;
01037 
01038     /* parameter */
01039     gsl_vector_int_set (wnode->Bf, j, 1);
01040     gnn_node_local_update (node);
01041 
01042     return 0;
01043 }
01044 
01045 /**
01046  * @brief Unfreeze a bias.
01047  * @ingroup gnn_weight_conn
01048  *
01049  * Unfreezes the bias associated to the j-th output.
01050  *
01051  * @param  node A \ref gnn_weight node.
01052  * @param  j    The index of the output.
01053  * @return Returns 0 if suceeded.
01054  */
01055 int
01056 gnn_weight_unfreeze_bias (gnn_node *node, size_t j)
01057 {
01058     gnn_weight *wnode;
01059 
01060     assert (node != NULL);
01061 
01062     gnn_node_local_get_f (node);
01063     wnode = (gnn_weight *) node;
01064 
01065     /* parameter */
01066     gsl_vector_int_set (wnode->Bf, j, 0);
01067     gnn_node_local_update (node);
01068 
01069     return 0;
01070 }
01071 
01072 /**
01073  * @brief Freeze all biases.
01074  * @ingroup gnn_weight_conn
01075  *
01076  * Freezes all biases.
01077  *
01078  * @param  node A \ref gnn_weight node.
01079  * @return Returns 0 if suceeded.
01080  */
01081 int
01082 gnn_weight_freeze_all_biases (gnn_node *node)
01083 {
01084     gnn_weight *wnode;
01085 
01086     assert (node != NULL);
01087 
01088     gnn_node_local_get_f (node);
01089     wnode = (gnn_weight *) node;
01090 
01091     /* parameter */
01092     gsl_vector_int_set_all (wnode->Bf, 1);
01093     gnn_node_local_update (node);
01094 
01095     return 0;
01096 }
01097 
01098 /**
01099  * @brief Unfreeze all biases.
01100  * @ingroup gnn_weight_conn
01101  *
01102  * Unfreezes all biases.
01103  *
01104  * @param  node A \ref gnn_weight node.
01105  * @return Returns 0 if suceeded.
01106  */
01107 int
01108 gnn_weight_unfreeze_all_biases (gnn_node *node)
01109 {
01110     gnn_weight *wnode;
01111 
01112     assert (node != NULL);
01113 
01114     gnn_node_local_get_f (node);
01115     wnode = (gnn_weight *) node;
01116 
01117     /* parameter */
01118     gsl_vector_int_set_all (wnode->Bf, 0);
01119     gnn_node_local_update (node);
01120 
01121     return 0;
01122 }
01123 
01124 
01125 
01126 

Generated on Sun Jun 13 20:50:12 2004 for libgnn Gradient Retropropagation Machine Library by doxygen1.2.18