00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093 #include <gsl/gsl_blas.h>
00094 #include "gnn_gcomm.h"
00095
00096
00097
00098
00099
00100
00101
00102 static int
00103 gnn_gcomm_f (gnn_node *node,
00104 const gsl_vector *x,
00105 const gsl_vector *w,
00106 gsl_vector *y);
00107
00108 static int
00109 gnn_gcomm_dx (gnn_node *node,
00110 const gsl_vector *x,
00111 const gsl_vector *w,
00112 const gsl_vector *dy,
00113 gsl_vector *dx);
00114
00115 static int
00116 gnn_gcomm_dw (gnn_node *node,
00117 const gsl_vector *x,
00118 const gsl_vector *w,
00119 const gsl_vector *dy,
00120 gsl_vector *dw);
00121
00122 static void
00123 gnn_gcomm_destroy (gnn_node *node);
00124
00125
00126
00127
00128
00129
00130
00131
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141 static int
00142 gnn_gcomm_f (gnn_node *node,
00143 const gsl_vector *x,
00144 const gsl_vector *w,
00145 gsl_vector *y)
00146 {
00147 size_t k;
00148 gnn_gcomm *gnode;
00149
00150
00151 gnode = (gnn_gcomm *) node;
00152
00153
00154 gsl_vector_memcpy (gnode->x, x);
00155
00156
00157 gnode->sumw = gsl_blas_dnrm2 (w);
00158 for (k=0; k<gnode->rep; ++k)
00159 {
00160 double wk;
00161 double alphak;
00162
00163
00164 wk = gsl_vector_get (w, k);
00165 wk = wk * wk;
00166
00167
00168 alphak = wk / gnode->sumw;
00169 gsl_vector_set (gnode->alpha, k, alphak);
00170 gsl_blas_daxpy (alphak, gnode->X[k], y);
00171 }
00172
00173 return 0;
00174 }
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187 static int
00188 gnn_gcomm_dx (gnn_node *node,
00189 const gsl_vector *x,
00190 const gsl_vector *w,
00191 const gsl_vector *dy,
00192 gsl_vector *dx)
00193 {
00194 size_t k;
00195 gnn_gcomm *gnode;
00196
00197
00198 gnode = (gnn_gcomm *) node;
00199
00200
00201 for (k=0; k<gnode->rep; ++k)
00202 {
00203 double alphak;
00204
00205 alphak = gsl_vector_get (gnode->alpha, k);
00206
00207 gsl_vector_memcpy (gnode->X[k], dy);
00208 gsl_vector_scale (gnode->X[k], alphak);
00209 }
00210
00211
00212 gsl_vector_memcpy (dx, gnode->x);
00213
00214 return 0;
00215 }
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227
00228 static int
00229 gnn_gcomm_dw (gnn_node *node,
00230 const gsl_vector *x,
00231 const gsl_vector *w,
00232 const gsl_vector *dy,
00233 gsl_vector *dw)
00234 {
00235 size_t k;
00236 size_t j;
00237 size_t kp;
00238 gnn_gcomm *gnode;
00239
00240
00241 gnode = (gnn_gcomm *) node;
00242
00243
00244 gsl_vector_set_zero (gnode->x);
00245 gsl_vector_memcpy (gnode->x, x);
00246
00247
00248 for (k=0; k<dw->size; ++k)
00249 {
00250 double dwk;
00251 double wk;
00252 double Sj;
00253
00254 wk = gsl_vector_get (w, k);
00255 dwk = gsl_vector_get (dw, k);
00256 Sj = 0.0;
00257
00258 for (j=0; j<dy->size; ++j)
00259 {
00260 double dyj;
00261 double Skp;
00262
00263 Skp = 0.0;
00264 dyj = gsl_vector_get (dy, j);
00265
00266 for (kp=0; kp<w->size; ++kp)
00267 {
00268 size_t i;
00269 double alphakp;
00270 double dalphakp;
00271 double xi;
00272
00273 i = kp * dy->size + j;
00274 xi = gsl_vector_get (gnode->x, i);
00275
00276 alphakp = gsl_vector_get (gnode->alpha, kp);
00277 dalphakp = (k == kp)? 1.0 - alphakp / gnode->sumw
00278 : - alphakp / gnode->sumw;
00279
00280 Skp += dalphakp * xi;
00281 }
00282
00283 Sj += dyj * 2 * wk * Skp;
00284 }
00285
00286 dwk += Sj;
00287 gsl_vector_set (dw, k, Sj);
00288 }
00289
00290 return 0;
00291 }
00292
00293
00294
00295
00296
00297
00298
00299 static void
00300 gnn_gcomm_destroy (gnn_node *node)
00301 {
00302 gnn_gcomm *gnode;
00303
00304 gnode = (gnn_gcomm *) node;
00305
00306 if (gnode->alpha != NULL)
00307 gsl_vector_free (gnode->alpha);
00308 if (gnode->xbuf != NULL)
00309 gsl_vector_free (gnode->xbuf);
00310 if (gnode->X_view != NULL)
00311 free (gnode->X_view);
00312 if (gnode->X != NULL)
00313 free (gnode->X);
00314 }
00315
00316
00317
00318
00319
00320
00321
00322
00323
00324
00325
00326
00327
00328
00329
00330
00331
00332
00333
00334 gnn_node *
00335 gnn_gcomm_new (size_t input_size, size_t output_size)
00336 {
00337 int status;
00338 size_t k;
00339 gsl_vector *w;
00340 gnn_node *node;
00341 gnn_gcomm *gnode;
00342 size_t param_size;
00343
00344
00345 if (output_size < 1)
00346 {
00347 GSL_ERROR_VAL ("output size should be strictly positive",
00348 GSL_EINVAL, NULL);
00349 }
00350 if (input_size < output_size)
00351 {
00352 GSL_ERROR_VAL ("input size should be greater than the output size",
00353 GSL_EINVAL, NULL);
00354 }
00355
00356
00357 gnode = (gnn_gcomm *) malloc (sizeof (gnn_gcomm));
00358 if (gnode == NULL)
00359 {
00360 GSL_ERROR_VAL ("could not allocate memory for gnn_gcomm node",
00361 GSL_ENOMEM, NULL);
00362 }
00363
00364
00365 node = (gnn_node *) gnode;
00366
00367
00368 status = gnn_node_init (node,
00369 "gnn_gcomm",
00370 gnn_gcomm_f,
00371 gnn_gcomm_dx,
00372 gnn_gcomm_dw,
00373 gnn_gcomm_destroy);
00374 if (status)
00375 {
00376 GSL_ERROR_VAL ("could not initialize gnn_gcomm node",
00377 GSL_EFAILED, NULL);
00378 }
00379
00380
00381 gnode->rep = (input_size-1) / output_size + 1;
00382
00383
00384 param_size = gnode->rep;
00385 status = gnn_node_set_sizes (node, input_size, output_size, param_size);
00386 if (status)
00387 {
00388 GSL_ERROR_VAL ("could not set sizes for gnn_gcomm node",
00389 GSL_EFAILED, NULL);
00390 }
00391
00392
00393 gnode->alpha = gsl_vector_calloc (param_size);
00394
00395
00396 gnode->xbuf = gsl_vector_calloc (gnode->rep * output_size);
00397
00398 gnode->X_view =
00399 (gsl_vector_view *) malloc (sizeof (gsl_vector_view) * gnode->rep);
00400
00401 gnode->X =
00402 (gsl_vector **) malloc (sizeof (gsl_vector *) * gnode->rep);
00403
00404 if (gnode->xbuf == NULL || gnode->X_view == NULL || gnode->X == NULL)
00405 {
00406 gnn_node_destroy (node);
00407 GSL_ERROR_VAL ("could not allocate buffers for gnn_gcomm node",
00408 GSL_ENOMEM, NULL);
00409 }
00410
00411
00412 gnode->x_view = gsl_vector_subvector (gnode->xbuf, 0, input_size);
00413 gnode->x = &(gnode->x_view.vector);
00414
00415 for (k=0; k<gnode->rep; ++k)
00416 {
00417 gnode->X_view[k] =
00418 gsl_vector_subvector (gnode->xbuf, k * output_size, output_size);
00419 gnode->X[k] = &(gnode->X_view[k].vector);
00420
00421 }
00422
00423
00424 gsl_vector_set_all (gnode->alpha, 1.0);
00425 gnn_node_param_set (node, gnode->alpha);
00426
00427 return node;
00428 }
00429
00430