00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059 #include "gnn_affine.h"
00060 #include <math.h>
00061
00062
00063
00064
00065
00066
00067
00068 static int
00069 gnn_affine_f (gnn_node *node,
00070 const gsl_vector *x,
00071 const gsl_vector *w,
00072 gsl_vector *y);
00073
00074 static int
00075 gnn_affine_dx (gnn_node *node,
00076 const gsl_vector *x,
00077 const gsl_vector *w,
00078 const gsl_vector *dy,
00079 gsl_vector *dx);
00080
00081 static int
00082 gnn_affine_dw (gnn_node *node,
00083 const gsl_vector *x,
00084 const gsl_vector *w,
00085 const gsl_vector *dy,
00086 gsl_vector *dw);
00087
00088 static void
00089 gnn_affine_destroy (gnn_node *node);
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107 static int
00108 gnn_affine_f (gnn_node *node,
00109 const gsl_vector *x,
00110 const gsl_vector *w,
00111 gsl_vector *y)
00112 {
00113 size_t i;
00114 size_t size;
00115 gnn_affine *anode;
00116
00117
00118 size = gnn_node_input_get_size (node);
00119
00120
00121 anode = (gnn_affine *) node;
00122
00123
00124 gsl_vector_memcpy (y, x);
00125 gsl_vector_mul (y, anode->a);
00126 gsl_vector_add (y, anode->b);
00127
00128 return 0;
00129 }
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142 static int
00143 gnn_affine_dx (gnn_node *node,
00144 const gsl_vector *x,
00145 const gsl_vector *w,
00146 const gsl_vector *dy,
00147 gsl_vector *dx)
00148 {
00149 size_t i;
00150 size_t size;
00151 gnn_affine *anode;
00152
00153
00154 size = gnn_node_input_get_size (node);
00155
00156
00157 anode = (gnn_affine *) node;
00158
00159
00160 gsl_vector_memcpy (dx, anode->a);
00161 gsl_vector_mul (dx, dy);
00162
00163 return 0;
00164 }
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177 static int
00178 gnn_affine_dw (gnn_node *node,
00179 const gsl_vector *x,
00180 const gsl_vector *w,
00181 const gsl_vector *dy,
00182 gsl_vector *dw)
00183 {
00184 size_t i;
00185 size_t size;
00186 gnn_affine *anode;
00187
00188
00189 size = gnn_node_input_get_size (node);
00190
00191
00192 anode = (gnn_affine *) node;
00193
00194
00195 gsl_vector_memcpy (anode->buf, x);
00196 gsl_vector_mul (anode->buf, dy);
00197 gsl_vector_add (anode->da, anode->buf);
00198
00199 gsl_vector_add (anode->db, dy);
00200
00201 return 0;
00202 }
00203
00204
00205
00206
00207
00208
00209
00210 static void
00211 gnn_affine_destroy (gnn_node *node)
00212 {
00213 gnn_affine *anode;
00214
00215 anode = (gnn_affine *) node;
00216
00217 if (anode->buf != NULL)
00218 gsl_vector_free (anode->buf);
00219 }
00220
00221
00222
00223
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240 gnn_node *
00241 gnn_affine_new (int input_size, double a, double b)
00242 {
00243 int status;
00244 gnn_node *node;
00245 gnn_affine *anode;
00246 gsl_vector *w;
00247 gsl_vector *dw;
00248 gsl_vector_int *f;
00249
00250
00251 if (input_size < 1)
00252 {
00253 GSL_ERROR_VAL ("input size should be strictly positive",
00254 GSL_EINVAL, NULL);
00255 }
00256
00257
00258 if (a <= 0.0)
00259 {
00260 GSL_ERROR_VAL ("amplitude factor a should be stricly positive",
00261 GSL_EINVAL, NULL);
00262 }
00263
00264
00265 anode = (gnn_affine *) malloc (sizeof (gnn_affine));
00266 if (anode == NULL)
00267 {
00268 GSL_ERROR_VAL ("could not allocate memory for gnn_affine node",
00269 GSL_ENOMEM, NULL);
00270 }
00271
00272
00273 node = (gnn_node *) anode;
00274
00275
00276 status = gnn_node_init (node,
00277 "gnn_affine",
00278 gnn_affine_f,
00279 gnn_affine_dx,
00280 gnn_affine_dw,
00281 NULL);
00282 if (status)
00283 {
00284 GSL_ERROR_VAL ("could not initialize gnn_affine node",
00285 GSL_EFAILED, NULL);
00286 }
00287
00288 status = gnn_node_set_sizes (node, input_size, input_size, 2 * input_size);
00289 if (status)
00290 {
00291 GSL_ERROR_VAL ("could not set sizes for gnn_affine node",
00292 GSL_EFAILED, NULL);
00293 }
00294
00295
00296 anode->buf = gsl_vector_alloc (input_size);
00297 if (anode->buf == NULL)
00298 {
00299 gnn_node_destroy (node);
00300 GSL_ERROR_VAL ("could not allocate buffer for gnn_affine node",
00301 GSL_ENOMEM, NULL);
00302 }
00303
00304
00305 w = gnn_node_local_get_w (node);
00306 dw = gnn_node_local_get_dw (node);
00307 f = gnn_node_local_get_f (node);
00308
00309 anode->a_view = gsl_vector_subvector (w, 0, input_size);
00310 anode->da_view = gsl_vector_subvector (dw, 0, input_size);
00311 anode->af_view = gsl_vector_int_subvector (f, 0, input_size);
00312
00313 anode->b_view = gsl_vector_subvector (w, input_size, input_size);
00314 anode->db_view = gsl_vector_subvector (dw, input_size, input_size);
00315 anode->bf_view = gsl_vector_int_subvector (f, input_size, input_size);
00316
00317 anode->a = &(anode->a_view.vector);
00318 anode->da = &(anode->da_view.vector);
00319 anode->af = &(anode->af_view.vector);
00320
00321 anode->b = &(anode->b_view.vector);
00322 anode->db = &(anode->db_view.vector);
00323 anode->bf = &(anode->bf_view.vector);
00324
00325
00326 gsl_vector_set_all (anode->a, a);
00327 gsl_vector_set_all (anode->b, b);
00328
00329
00330 gnn_node_local_update (node);
00331
00332 return node;
00333 }
00334
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345 gnn_node *
00346 gnn_affine_standard_new (int input_size)
00347 {
00348 return gnn_affine_new (input_size, 1.0, 0.0);
00349 }
00350
00351